Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)

Commit f12ca56e93: Merge intermediate changes

47 changed files with 1874 additions and 744 deletions
6  .github/workflows/build.yml (vendored)

@@ -11,6 +11,10 @@ on:
- Dockerfile
- docker-compose.yml
- .github/workflows/build.yml
workflow_dispatch:

env:
DOCKER_IMAGE_TAG: ${{ github.ref == 'refs/heads/master' && 'latest' || github.ref }}

jobs:
build:

@@ -36,6 +40,6 @@ jobs:
context: .
file: Dockerfile
push: true
tags: ghcr.io/${{ github.repository }}:latest
tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
build-args: |
PORT=8000
6  .github/workflows/release.yml (vendored)

@@ -1,15 +1,15 @@
name: release

on:
push:
tags:
- v*
workflow_dispatch:
inputs:
version:
description: 'Version Number'
required: true
type: string
push:
tags:
- v*

jobs:
publish:
1  .gitignore (vendored)

@@ -13,3 +13,4 @@ src/.data
/dist/
/khoj_assistant.egg-info/
/config/khoj*.yml
.pytest_cache
Dockerfile

@@ -4,7 +4,7 @@ LABEL org.opencontainers.image.source https://github.com/debanjum/khoj

# Install System Dependencies
RUN apt-get update -y && \
apt-get -y install libimage-exiftool-perl
apt-get -y install libimage-exiftool-perl python3-pyqt5

# Copy Application to Container
COPY . /app
Readme.md

@@ -2,7 +2,6 @@
[![build](https://github.com/debanjum/khoj/actions/workflows/build.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/build.yml)
[![test](https://github.com/debanjum/khoj/actions/workflows/test.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/test.yml)
[![publish](https://github.com/debanjum/khoj/actions/workflows/publish.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/publish.yml)
[![release](https://github.com/debanjum/khoj/actions/workflows/release.yml/badge.svg)](https://github.com/debanjum/khoj/actions/workflows/release.yml)

*A natural language search engine for your personal notes, transactions and images*

@@ -107,7 +106,7 @@ pip install --upgrade khoj-assistant
## Troubleshoot

- Symptom: Errors out complaining about Tensors mismatch, null etc
  - Mitigation: Disable `image` search on the desktop GUI
  - Mitigation: Disable `image` search using the desktop GUI
- Symptom: Errors out with "Killed" in error message in Docker
  - Fix: Increase RAM available to Docker Containers in Docker Settings
  - Refer: [StackOverflow Solution](https://stackoverflow.com/a/50770267), [Configure Resources on Docker for Mac](https://docs.docker.com/desktop/mac/#resources)

@@ -125,14 +124,14 @@ pip install --upgrade khoj-assistant

- Semantic search using the bi-encoder is fairly fast at <50 ms
- Reranking using the cross-encoder is slower at <2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
- Applying explicit filters is very slow currently at ~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
- Filters in query (e.g by file, word or date) usually add <20ms to query latency

### Indexing performance

- Indexing is more strongly impacted by the size of the source data
- Indexing 100K+ line corpus of notes takes 6 minutes
- Indexing 100K+ line corpus of notes takes about 10 minutes
- Indexing 4000+ images takes about 15 minutes and more than 8Gb of RAM
- Once <https://github.com/debanjum/khoj/issues/36> is implemented, it should only take this long on first run
- Note: *It should only take this long on the first run* as the index is incrementally updated

### Miscellaneous
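The performance notes above mention result counts, reranking and query latency. A minimal sketch of how those trade-offs can be observed, assuming a Khoj server running locally on port 8000 (the port used in docker-compose.yml below) and the `/search` query parameters (`q`, `t`, `n`, `r`) that the web interface changes further down in this diff construct:

```python
# Hedged sketch: time searches against a locally running Khoj server.
# Assumes http://localhost:8000 and the /search?q=&t=&n=&r= parameters
# used by the web interface in this commit; not a documented client API.
import time

import requests


def timed_search(query: str, search_type: str = "org", count: int = 6, rerank: bool = False):
    params = {"q": query, "t": search_type, "n": count, "r": str(rerank).lower()}
    start = time.time()
    response = requests.get("http://localhost:8000/search", params=params)
    response.raise_for_status()
    elapsed = time.time() - start
    print(f"{len(response.json())} results in {elapsed:.2f}s (rerank={rerank})")
    return response.json()


if __name__ == "__main__":
    # Bi-encoder only search should stay well under a second; enabling
    # cross-encoder reranking trades speed for result accuracy.
    timed_search("meaning of life", count=6, rerank=False)
    timed_search("meaning of life", count=6, rerank=True)
```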
docker-compose.yml

@@ -28,4 +28,4 @@ services:
- ./tests/data/embeddings/:/data/embeddings/
- ./tests/data/models/:/data/models/
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
command: config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv
command: --no-gui -c=config/khoj_docker.yml --host="0.0.0.0" --port=8000 -vv
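The comment added to the compose file explains why the service binds to 0.0.0.0 inside the container. For reference, a minimal sketch of the equivalent programmatic bind with uvicorn and FastAPI (the stack src/main.py below runs on); the bare app object here is a stand-in, not Khoj's:

```python
# Minimal sketch: bind a FastAPI app to all interfaces inside a container,
# matching the --host="0.0.0.0" --port=8000 flags in docker-compose.yml.
# Binding to 127.0.0.1 instead would leave the port unreachable from the
# host even when it is published with -p 8000:8000.
import uvicorn
from fastapi import FastAPI

app = FastAPI()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```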
2  setup.py

@@ -7,7 +7,7 @@ this_directory = Path(__file__).parent

setup(
    name='khoj-assistant',
    version='0.1.6',
    version='0.1.10',
    description="A natural language search engine for your personal notes, transactions and images",
    long_description=(this_directory / "Readme.md").read_text(encoding="utf-8"),
    long_description_content_type="text/markdown",
|
src/configure.py

@@ -1,8 +1,8 @@
|
|||
# System Packages
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
import json
|
||||
|
||||
# Internal Packages
|
||||
|
@@ -13,8 +13,14 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
|||
from src.search_type import image_search, text_search
|
||||
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path
|
||||
from src.utils.helpers import LRU, resolve_absolute_path
|
||||
from src.utils.rawconfig import FullConfig, ProcessorConfig
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
from src.search_filter.word_filter import WordFilter
|
||||
from src.search_filter.file_filter import FileFilter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure_server(args, required=False):
|
||||
|
@@ -28,27 +34,42 @@ def configure_server(args, required=False):
|
|||
state.config = args.config
|
||||
|
||||
# Initialize the search model from Config
|
||||
state.model = configure_search(state.model, state.config, args.regenerate, verbose=state.verbose)
|
||||
state.model = configure_search(state.model, state.config, args.regenerate)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(args.config.processor, verbose=state.verbose)
|
||||
state.processor_config = configure_processor(args.config.processor)
|
||||
|
||||
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None, verbose: int = 0):
|
||||
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
|
||||
# Initialize Org Notes Search
|
||||
if (t == SearchType.Org or t == None) and config.content_type.org:
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.orgmode_search = text_search.setup(
|
||||
org_to_jsonl,
|
||||
config.content_type.org,
|
||||
search_config=config.search_type.asymmetric,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
|
||||
# Initialize Org Music Search
|
||||
if (t == SearchType.Music or t == None) and config.content_type.music:
|
||||
# Extract Entries, Generate Music Embeddings
|
||||
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.music_search = text_search.setup(
|
||||
org_to_jsonl,
|
||||
config.content_type.music,
|
||||
search_config=config.search_type.asymmetric,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter()])
|
||||
|
||||
# Initialize Markdown Search
|
||||
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.markdown_search = text_search.setup(
|
||||
markdown_to_jsonl,
|
||||
config.content_type.markdown,
|
||||
search_config=config.search_type.asymmetric,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
|
||||
# Initialize Panchayat Search
|
||||
if (t == SearchType.Panchayat or t == None) and config.content_type.panchayat:
|
||||
|
@@ -58,17 +79,28 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
|
|||
# Initialize Ledger Search
|
||||
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
|
||||
# Extract Entries, Generate Ledger Embeddings
|
||||
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
|
||||
model.ledger_search = text_search.setup(
|
||||
beancount_to_jsonl,
|
||||
config.content_type.ledger,
|
||||
search_config=config.search_type.symmetric,
|
||||
regenerate=regenerate,
|
||||
filters=[DateFilter(), WordFilter(), FileFilter()])
|
||||
|
||||
# Initialize Image Search
|
||||
if (t == SearchType.Image or t == None) and config.content_type.image:
|
||||
# Extract Entries, Generate Image Embeddings
|
||||
model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)
|
||||
model.image_search = image_search.setup(
|
||||
config.content_type.image,
|
||||
search_config=config.search_type.image,
|
||||
regenerate=regenerate)
|
||||
|
||||
# Invalidate Query Cache
|
||||
state.query_cache = LRU()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def configure_processor(processor_config: ProcessorConfig, verbose: int):
|
||||
def configure_processor(processor_config: ProcessorConfig):
|
||||
if not processor_config:
|
||||
return
|
||||
|
||||
|
@@ -76,24 +108,20 @@ def configure_processor(processor_config: ProcessorConfig, verbose: int):
|
|||
|
||||
# Initialize Conversation Processor
|
||||
if processor_config.conversation:
|
||||
processor.conversation = configure_conversation_processor(processor_config.conversation, verbose)
|
||||
processor.conversation = configure_conversation_processor(processor_config.conversation)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def configure_conversation_processor(conversation_processor_config, verbose: int):
|
||||
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config, verbose)
|
||||
def configure_conversation_processor(conversation_processor_config):
|
||||
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
|
||||
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
|
||||
|
||||
conversation_logfile = conversation_processor.conversation_logfile
|
||||
if conversation_processor.verbose:
|
||||
print('INFO:\tLoading conversation logs from disk...')
|
||||
|
||||
if conversation_logfile.expanduser().absolute().is_file():
|
||||
if conversation_logfile.is_file():
|
||||
# Load Metadata Logs from Conversation Logfile
|
||||
with open(get_absolute_path(conversation_logfile), 'r') as f:
|
||||
with conversation_logfile.open('r') as f:
|
||||
conversation_processor.meta_log = json.load(f)
|
||||
|
||||
print('INFO:\tConversation logs loaded from disk.')
|
||||
logger.info('Conversation logs loaded from disk.')
|
||||
else:
|
||||
# Initialize Conversation Logs
|
||||
conversation_processor.meta_log = {}
|
||||
|
|
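The hunks above thread explicit DateFilter/WordFilter/FileFilter instances into each text_search.setup call and drop the verbose parameter in favour of the shared logger. A hedged sketch of invoking the refactored configure_search for a single content type, assuming state.config already holds a loaded FullConfig as in configure_server:

```python
# Hedged sketch: rebuild only the org-mode index via the refactored
# configure_search signature shown above. Assumes state.config is a loaded
# FullConfig and state.model the current SearchModels, as in configure_server.
from src.configure import configure_search
from src.utils import state
from src.utils.config import SearchType

# Only the org-mode embeddings are regenerated; other search types are left
# untouched because configure_search only initializes the type passed in `t`.
state.model = configure_search(state.model, state.config, regenerate=True, t=SearchType.Org)
```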
|
src/interface/desktop/main_window.py

@@ -92,7 +92,7 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
search_type_layout = QtWidgets.QVBoxLayout(search_type_settings)
|
||||
enable_search_type = SearchCheckBox(f"Search {search_type.name}", search_type)
|
||||
# Add file browser to set input files for given search type
|
||||
input_files = FileBrowser(file_input_text, search_type, current_content_files)
|
||||
input_files = FileBrowser(file_input_text, search_type, current_content_files or [])
|
||||
|
||||
# Set enabled/disabled based on checkbox state
|
||||
enable_search_type.setChecked(current_content_files is not None and len(current_content_files) > 0)
|
||||
|
|
|
src/interface/desktop/system_tray.py

@@ -6,9 +6,10 @@ from PyQt6 import QtGui, QtWidgets
|
|||
|
||||
# Internal Packages
|
||||
from src.utils import constants, state
|
||||
from src.interface.desktop.main_window import MainWindow
|
||||
|
||||
|
||||
def create_system_tray(gui: QtWidgets.QApplication, main_window: QtWidgets.QMainWindow):
|
||||
def create_system_tray(gui: QtWidgets.QApplication, main_window: MainWindow):
|
||||
"""Create System Tray with Menu. Menu contain options to
|
||||
1. Open Search Page on the Web Interface
|
||||
2. Open App Configuration Screen
|
||||
|
|
|
src/interface/emacs/khoj.el

@@ -5,7 +5,7 @@
|
|||
;; Author: Debanjum Singh Solanky <debanjum@gmail.com>
|
||||
;; Description: Natural, Incremental Search for your Second Brain
|
||||
;; Keywords: search, org-mode, outlines, markdown, beancount, ledger, image
|
||||
;; Version: 0.1.6
|
||||
;; Version: 0.1.9
|
||||
;; Package-Requires: ((emacs "27.1"))
|
||||
;; URL: http://github.com/debanjum/khoj/interface/emacs
|
||||
|
||||
|
|
|
src/interface/web/index.html

@@ -61,12 +61,26 @@
|
|||
}
|
||||
|
||||
function search(rerank=false) {
|
||||
query = document.getElementById("query").value;
|
||||
// Extract required fields for search from form
|
||||
query = document.getElementById("query").value.trim();
|
||||
type = document.getElementById("type").value;
|
||||
console.log(query, type);
|
||||
results_count = document.getElementById("results-count").value || 6;
|
||||
console.log(`Query: ${query}, Type: ${type}`);
|
||||
|
||||
// Short circuit on empty query
|
||||
if (query.length === 0)
|
||||
return;
|
||||
|
||||
// If set query field in url query param on rerank
|
||||
if (rerank)
|
||||
setQueryFieldInUrl(query);
|
||||
|
||||
// Generate Backend API URL to execute Search
|
||||
url = type === "image"
|
||||
? `/search?q=${query}&t=${type}&n=6`
|
||||
: `/search?q=${query}&t=${type}&n=6&r=${rerank}`;
|
||||
? `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}`
|
||||
: `/search?q=${encodeURIComponent(query)}&t=${type}&n=${results_count}&r=${rerank}`;
|
||||
|
||||
// Execute Search and Render Results
|
||||
fetch(url)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
|
@@ -78,9 +92,9 @@
|
|||
});
|
||||
}
|
||||
|
||||
function regenerate() {
|
||||
function updateIndex() {
|
||||
type = document.getElementById("type").value;
|
||||
fetch(`/regenerate?t=${type}`)
|
||||
fetch(`/reload?t=${type}`)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
|
@@ -89,7 +103,7 @@
|
|||
});
|
||||
}
|
||||
|
||||
function incremental_search(event) {
|
||||
function incrementalSearch(event) {
|
||||
type = document.getElementById("type").value;
|
||||
// Search with reranking on 'Enter'
|
||||
if (event.key === 'Enter') {
|
||||
|
@@ -121,10 +135,33 @@
|
|||
});
|
||||
}
|
||||
|
||||
function setTypeFieldInUrl(type) {
|
||||
var url = new URL(window.location.href);
|
||||
url.searchParams.set("t", type.value);
|
||||
window.history.pushState({}, "", url.href);
|
||||
}
|
||||
|
||||
function setCountFieldInUrl(results_count) {
|
||||
var url = new URL(window.location.href);
|
||||
url.searchParams.set("n", results_count.value);
|
||||
window.history.pushState({}, "", url.href);
|
||||
}
|
||||
|
||||
function setQueryFieldInUrl(query) {
|
||||
var url = new URL(window.location.href);
|
||||
url.searchParams.set("q", query);
|
||||
window.history.pushState({}, "", url.href);
|
||||
}
|
||||
|
||||
window.onload = function () {
|
||||
// Dynamically populate type dropdown based on enabled search types and type passed as URL query parameter
|
||||
populate_type_dropdown();
|
||||
|
||||
// Set results count field with value passed in URL query parameters, if any.
|
||||
var results_count = new URLSearchParams(window.location.search).get("n");
|
||||
if (results_count)
|
||||
document.getElementById("results-count").value = results_count;
|
||||
|
||||
// Fill query field with value passed in URL query parameters, if any.
|
||||
var query_via_url = new URLSearchParams(window.location.search).get("q");
|
||||
if (query_via_url)
|
||||
|
@@ -136,14 +173,17 @@
|
|||
<h1>Khoj</h1>
|
||||
|
||||
<!--Add Text Box To Enter Query, Trigger Incremental Search OnChange -->
|
||||
<input type="text" id="query" onkeyup=incremental_search(event) autofocus="autofocus" placeholder="What is the meaning of life?">
|
||||
<input type="text" id="query" onkeyup=incrementalSearch(event) autofocus="autofocus" placeholder="What is the meaning of life?">
|
||||
|
||||
<div id="options">
|
||||
<!--Add Dropdown to Select Query Type -->
|
||||
<select id="type"></select>
|
||||
<select id="type" onchange="setTypeFieldInUrl(this)"></select>
|
||||
|
||||
<!--Add Button To Regenerate -->
|
||||
<button id="regenerate" onclick="regenerate()">Regenerate</button>
|
||||
<button id="update" onclick="updateIndex()">Update</button>
|
||||
|
||||
<!--Add Results Count Input To Set Results Count -->
|
||||
<input type="number" id="results-count" min="1" max="100" value="6" placeholder="results count" onchange="setCountFieldInUrl(this)">
|
||||
</div>
|
||||
|
||||
<!-- Section to Render Results -->
|
||||
|
@@ -194,7 +234,7 @@
|
|||
#options {
|
||||
padding: 0;
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
grid-template-columns: 1fr 1fr minmax(70px, 0.5fr);
|
||||
}
|
||||
#options > * {
|
||||
padding: 15px;
|
||||
|
@@ -202,10 +242,10 @@
|
|||
border: 1px solid #ccc;
|
||||
}
|
||||
#options > select {
|
||||
margin-right: 5px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
#options > button {
|
||||
margin-left: 5px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
#query {
|
||||
|
|
56  src/main.py

@@ -2,8 +2,13 @@
|
|||
import os
|
||||
import signal
|
||||
import sys
|
||||
import logging
|
||||
import warnings
|
||||
from platform import system
|
||||
|
||||
# Ignore non-actionable warnings
|
||||
warnings.filterwarnings("ignore", message=r'snapshot_download.py has been made private', category=FutureWarning)
|
||||
|
||||
# External Packages
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
@@ -25,6 +30,34 @@ app = FastAPI()
|
|||
app.mount("/static", StaticFiles(directory=constants.web_directory), name="static")
|
||||
app.include_router(router)
|
||||
|
||||
logger = logging.getLogger('src')
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
|
||||
blue = "\x1b[1;34m"
|
||||
green = "\x1b[1;32m"
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
reset = "\x1b[0m"
|
||||
format_str = "%(levelname)s: %(asctime)s: %(name)s | %(message)s"
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: blue + format_str + reset,
|
||||
logging.INFO: green + format_str + reset,
|
||||
logging.WARNING: yellow + format_str + reset,
|
||||
logging.ERROR: red + format_str + reset,
|
||||
logging.CRITICAL: bold_red + format_str + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt)
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
def run():
|
||||
# Turn Tokenizers Parallelism Off. App does not support it.
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
|
||||
|
@@ -34,6 +67,29 @@ def run():
|
|||
args = cli(state.cli_args)
|
||||
set_state(args)
|
||||
|
||||
# Create app directory, if it doesn't exist
|
||||
state.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Setup Logger
|
||||
if args.verbose == 0:
|
||||
logger.setLevel(logging.WARN)
|
||||
elif args.verbose == 1:
|
||||
logger.setLevel(logging.INFO)
|
||||
elif args.verbose >= 2:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Set Log Format
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(CustomFormatter())
|
||||
logger.addHandler(ch)
|
||||
|
||||
# Set Log File
|
||||
fh = logging.FileHandler(state.config_file.parent / 'khoj.log')
|
||||
fh.setLevel(logging.DEBUG)
|
||||
logger.addHandler(fh)
|
||||
|
||||
logger.info("Starting Khoj...")
|
||||
|
||||
if args.no_gui:
|
||||
# Start Server
|
||||
configure_server(args, required=True)
|
||||
|
|
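run() above attaches a colour StreamHandler and a plain FileHandler to the 'src' logger, so module loggers created with logging.getLogger(__name__) anywhere under the src package propagate to it without handlers of their own. A small sketch of that usage; the module name and rendered timestamp are illustrative:

```python
# Sketch: module loggers under the src package propagate to the 'src' logger
# configured in run(), so they need no handlers of their own.
import logging

logger = logging.getLogger("src.search_type.text_search")  # child of 'src'
logger.info("Loading embeddings from disk")

# With the format string "%(levelname)s: %(asctime)s: %(name)s | %(message)s"
# the stream handler would render this roughly as:
# INFO: 2022-09-01 12:00:00,000: src.search_type.text_search | Loading embeddings from disk
```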
|
src/processor/ledger/beancount_to_jsonl.py

@@ -2,108 +2,127 @@
|
|||
|
||||
# Standard Packages
|
||||
import json
|
||||
import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import re
|
||||
import logging
|
||||
import time
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, verbose=0):
|
||||
def beancount_to_jsonl(config: TextContentConfig, previous_entries=None):
|
||||
# Extract required fields from config
|
||||
beancount_files, beancount_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
|
||||
print("At least one of beancount-files or beancount-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Beancount Files to Process
|
||||
beancount_files = get_beancount_files(beancount_files, beancount_file_filter, verbose)
|
||||
beancount_files = get_beancount_files(beancount_files, beancount_file_filter)
|
||||
|
||||
# Extract Entries from specified Beancount files
|
||||
entries = extract_beancount_entries(beancount_files)
|
||||
start = time.time()
|
||||
current_entries = convert_transactions_to_maps(*extract_beancount_transactions(beancount_files))
|
||||
end = time.time()
|
||||
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
start = time.time()
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
end = time.time()
|
||||
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_beancount_entries_to_jsonl(entries, verbose=verbose)
|
||||
start = time.time()
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = convert_transaction_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file, verbose=verbose)
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
|
||||
|
||||
return entries
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
def get_beancount_files(beancount_files=None, beancount_file_filter=None, verbose=0):
|
||||
def get_beancount_files(beancount_files=None, beancount_file_filters=None):
|
||||
"Get Beancount files to process"
|
||||
absolute_beancount_files, filtered_beancount_files = set(), set()
|
||||
if beancount_files:
|
||||
absolute_beancount_files = {get_absolute_path(beancount_file)
|
||||
for beancount_file
|
||||
in beancount_files}
|
||||
if beancount_file_filter:
|
||||
filtered_beancount_files = set(glob.glob(get_absolute_path(beancount_file_filter)))
|
||||
if beancount_file_filters:
|
||||
filtered_beancount_files = {
|
||||
filtered_file
|
||||
for beancount_file_filter in beancount_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(beancount_file_filter))
|
||||
}
|
||||
|
||||
all_beancount_files = absolute_beancount_files | filtered_beancount_files
|
||||
all_beancount_files = sorted(absolute_beancount_files | filtered_beancount_files)
|
||||
|
||||
files_with_non_beancount_extensions = {beancount_file
|
||||
files_with_non_beancount_extensions = {
|
||||
beancount_file
|
||||
for beancount_file
|
||||
in all_beancount_files
|
||||
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")}
|
||||
if not beancount_file.endswith(".bean") and not beancount_file.endswith(".beancount")
|
||||
}
|
||||
if any(files_with_non_beancount_extensions):
|
||||
print(f"[Warning] There maybe non beancount files in the input set: {files_with_non_beancount_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_beancount_files}')
|
||||
logger.info(f'Processing files: {all_beancount_files}')
|
||||
|
||||
return all_beancount_files
|
||||
|
||||
|
||||
def extract_beancount_entries(beancount_files):
|
||||
def extract_beancount_transactions(beancount_files):
|
||||
"Extract entries from specified Beancount files"
|
||||
|
||||
# Initialize Regex for extracting Beancount Entries
|
||||
transaction_regex = r'^\n?\d{4}-\d{2}-\d{2} [\*|\!] '
|
||||
empty_newline = f'^[{empty_escape_sequences}]*$'
|
||||
empty_newline = f'^[\n\r\t\ ]*$'
|
||||
|
||||
entries = []
|
||||
transaction_to_file_map = []
|
||||
for beancount_file in beancount_files:
|
||||
with open(beancount_file) as f:
|
||||
ledger_content = f.read()
|
||||
entries.extend([entry.strip(empty_escape_sequences)
|
||||
transactions_per_file = [entry.strip(empty_escape_sequences)
|
||||
for entry
|
||||
in re.split(empty_newline, ledger_content, flags=re.MULTILINE)
|
||||
if re.match(transaction_regex, entry)])
|
||||
|
||||
return entries
|
||||
if re.match(transaction_regex, entry)]
|
||||
transaction_to_file_map += zip(transactions_per_file, [beancount_file]*len(transactions_per_file))
|
||||
entries.extend(transactions_per_file)
|
||||
return entries, dict(transaction_to_file_map)
|
||||
|
||||
|
||||
def convert_beancount_entries_to_jsonl(entries, verbose=0):
|
||||
"Convert each Beancount transaction to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]:
|
||||
"Convert each Beancount transaction into a dictionary"
|
||||
entry_maps = []
|
||||
for entry in entries:
|
||||
entry_dict = {'compiled': entry, 'raw': entry}
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'})
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
logger.info(f"Converted {len(entries)} transactions to dictionaries")
|
||||
|
||||
return jsonl
|
||||
return entry_maps
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup Argument Parser
|
||||
parser = argparse.ArgumentParser(description="Map Beancount transactions into (compressed) JSONL format")
|
||||
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted transactions. Expected file extensions: jsonl or jsonl.gz")
|
||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of beancount files to process")
|
||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for beancount files to process")
|
||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Map transactions in beancount files to (compressed) JSONL formatted file
|
||||
beancount_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
|
||||
def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str:
|
||||
"Convert each Beancount transaction dictionary to JSON and collate as JSONL"
|
||||
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
|
||||
|
|
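beancount_to_jsonl above (and the markdown and org converters below) now accept previous_entries and return (id, entry) pairs, delegating the diffing to mark_entries_for_update from src.utils.helpers, which is not part of this excerpt. A rough, assumption-laden illustration of the behaviour the callers rely on, matching entries by their 'compiled' text and reusing ids of unchanged ones; it is not the committed implementation:

```python
# Rough illustration only: the real mark_entries_for_update lives in
# src/utils/helpers.py and is not shown in this diff. The callers above only
# rely on it returning (id, entry) pairs where unchanged entries keep their
# previous id and new or updated entries get fresh ids.
def mark_entries_for_update_sketch(current_entries, previous_entries, key='compiled', logger=None):
    previous_by_key = {entry[key]: (id, entry) for id, entry in enumerate(previous_entries)}
    next_id = len(previous_entries)
    entries_with_ids = []
    for entry in current_entries:
        if entry[key] in previous_by_key:
            # Unchanged entry: keep its previous id so its embedding can be reused
            entries_with_ids.append(previous_by_key[entry[key]])
        else:
            # New or updated entry: assign a fresh id beyond the previous range
            entries_with_ids.append((next_id, entry))
            next_id += 1
            if logger:
                logger.debug(f"Mark new/updated entry for embedding: {entry[key][:30]}")
    return entries_with_ids
```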
|
src/processor/markdown/markdown_to_jsonl.py

@@ -2,51 +2,78 @@
|
|||
|
||||
# Standard Packages
|
||||
import json
|
||||
import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import re
|
||||
import logging
|
||||
import time
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file, verbose=0):
|
||||
def markdown_to_jsonl(config: TextContentConfig, previous_entries=None):
|
||||
# Extract required fields from config
|
||||
markdown_files, markdown_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(markdown_files) and is_none_or_empty(markdown_file_filter):
|
||||
print("At least one of markdown-files or markdown-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Markdown Files to Process
|
||||
markdown_files = get_markdown_files(markdown_files, markdown_file_filter, verbose)
|
||||
markdown_files = get_markdown_files(markdown_files, markdown_file_filter)
|
||||
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = extract_markdown_entries(markdown_files)
|
||||
start = time.time()
|
||||
current_entries = convert_markdown_entries_to_maps(*extract_markdown_entries(markdown_files))
|
||||
end = time.time()
|
||||
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
start = time.time()
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
end = time.time()
|
||||
logger.debug(f"Identify new or updated entries: {end - start} seconds")
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_markdown_entries_to_jsonl(entries, verbose=verbose)
|
||||
start = time.time()
|
||||
entries = list(map(lambda entry: entry[1], entries_with_ids))
|
||||
jsonl_data = convert_markdown_maps_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file, verbose=verbose)
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
|
||||
|
||||
return entries
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0):
|
||||
def get_markdown_files(markdown_files=None, markdown_file_filters=None):
|
||||
"Get Markdown files to process"
|
||||
absolute_markdown_files, filtered_markdown_files = set(), set()
|
||||
if markdown_files:
|
||||
absolute_markdown_files = {get_absolute_path(markdown_file) for markdown_file in markdown_files}
|
||||
if markdown_file_filter:
|
||||
filtered_markdown_files = set(glob.glob(get_absolute_path(markdown_file_filter)))
|
||||
if markdown_file_filters:
|
||||
filtered_markdown_files = {
|
||||
filtered_file
|
||||
for markdown_file_filter in markdown_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(markdown_file_filter))
|
||||
}
|
||||
|
||||
all_markdown_files = absolute_markdown_files | filtered_markdown_files
|
||||
all_markdown_files = sorted(absolute_markdown_files | filtered_markdown_files)
|
||||
|
||||
files_with_non_markdown_extensions = {
|
||||
md_file
|
||||
|
@@ -56,10 +83,9 @@ def get_markdown_files(markdown_files=None, markdown_file_filter=None, verbose=0
|
|||
}
|
||||
|
||||
if any(files_with_non_markdown_extensions):
|
||||
print(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
|
||||
logger.warn(f"[Warning] There maybe non markdown-mode files in the input set: {files_with_non_markdown_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_markdown_files}')
|
||||
logger.info(f'Processing files: {all_markdown_files}')
|
||||
|
||||
return all_markdown_files
|
||||
|
||||
|
@@ -71,38 +97,31 @@ def extract_markdown_entries(markdown_files):
|
|||
markdown_heading_regex = r'^#'
|
||||
|
||||
entries = []
|
||||
entry_to_file_map = []
|
||||
for markdown_file in markdown_files:
|
||||
with open(markdown_file) as f:
|
||||
markdown_content = f.read()
|
||||
entries.extend([f'#{entry.strip(empty_escape_sequences)}'
|
||||
markdown_entries_per_file = [f'#{entry.strip(empty_escape_sequences)}'
|
||||
for entry
|
||||
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)])
|
||||
in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
|
||||
if entry.strip(empty_escape_sequences) != '']
|
||||
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file]*len(markdown_entries_per_file))
|
||||
entries.extend(markdown_entries_per_file)
|
||||
|
||||
return entries
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
|
||||
def convert_markdown_entries_to_jsonl(entries, verbose=0):
|
||||
"Convert each Markdown entries to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]:
|
||||
"Convert each Markdown entries into a dictionary"
|
||||
entry_maps = []
|
||||
for entry in entries:
|
||||
entry_dict = {'compiled': entry, 'raw': entry}
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'})
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
logger.info(f"Converted {len(entries)} markdown entries to dictionaries")
|
||||
|
||||
return jsonl
|
||||
return entry_maps
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup Argument Parser
|
||||
parser = argparse.ArgumentParser(description="Map Markdown entries into (compressed) JSONL format")
|
||||
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. Expected file extensions: jsonl or jsonl.gz")
|
||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of markdown files to process")
|
||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for markdown files to process")
|
||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Map notes in Markdown files to (compressed) JSONL formatted file
|
||||
markdown_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
|
||||
def convert_markdown_maps_to_jsonl(entries):
|
||||
"Convert each Markdown entries to JSON and collate as JSONL"
|
||||
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
|
||||
|
|
|
src/processor/org_mode/org_to_jsonl.py

@@ -1,62 +1,94 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Standard Packages
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import pathlib
|
||||
import glob
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
|
||||
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||
from src.utils import state
|
||||
from src.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define Functions
|
||||
def org_to_jsonl(org_files, org_file_filter, output_file, verbose=0):
|
||||
def org_to_jsonl(config: TextContentConfig, previous_entries=None):
|
||||
# Extract required fields from config
|
||||
org_files, org_file_filter, output_file = config.input_files, config.input_filter, config.compressed_jsonl
|
||||
index_heading_entries = config.index_heading_entries
|
||||
|
||||
# Input Validation
|
||||
if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
|
||||
print("At least one of org-files or org-file-filter is required to be specified")
|
||||
exit(1)
|
||||
|
||||
# Get Org Files to Process
|
||||
org_files = get_org_files(org_files, org_file_filter, verbose)
|
||||
start = time.time()
|
||||
org_files = get_org_files(org_files, org_file_filter)
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
entries = extract_org_entries(org_files)
|
||||
start = time.time()
|
||||
entry_nodes, file_to_entries = extract_org_entries(org_files)
|
||||
end = time.time()
|
||||
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
|
||||
|
||||
start = time.time()
|
||||
current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
end = time.time()
|
||||
logger.debug(f"Convert OrgNodes into entry dictionaries: {end - start} seconds")
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
if not previous_entries:
|
||||
entries_with_ids = list(enumerate(current_entries))
|
||||
else:
|
||||
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries, verbose=verbose)
|
||||
start = time.time()
|
||||
entries = map(lambda entry: entry[1], entries_with_ids)
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries)
|
||||
|
||||
# Compress JSONL formatted Data
|
||||
if output_file.suffix == ".gz":
|
||||
compress_jsonl_data(jsonl_data, output_file, verbose=verbose)
|
||||
compress_jsonl_data(jsonl_data, output_file)
|
||||
elif output_file.suffix == ".jsonl":
|
||||
dump_jsonl(jsonl_data, output_file, verbose=verbose)
|
||||
dump_jsonl(jsonl_data, output_file)
|
||||
end = time.time()
|
||||
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
|
||||
|
||||
return entries
|
||||
return entries_with_ids
|
||||
|
||||
|
||||
def get_org_files(org_files=None, org_file_filter=None, verbose=0):
|
||||
def get_org_files(org_files=None, org_file_filters=None):
|
||||
"Get Org files to process"
|
||||
absolute_org_files, filtered_org_files = set(), set()
|
||||
if org_files:
|
||||
absolute_org_files = {get_absolute_path(org_file)
|
||||
absolute_org_files = {
|
||||
get_absolute_path(org_file)
|
||||
for org_file
|
||||
in org_files}
|
||||
if org_file_filter:
|
||||
filtered_org_files = set(glob.glob(get_absolute_path(org_file_filter)))
|
||||
in org_files
|
||||
}
|
||||
if org_file_filters:
|
||||
filtered_org_files = {
|
||||
filtered_file
|
||||
for org_file_filter in org_file_filters
|
||||
for filtered_file in glob.glob(get_absolute_path(org_file_filter))
|
||||
}
|
||||
|
||||
all_org_files = absolute_org_files | filtered_org_files
|
||||
all_org_files = sorted(absolute_org_files | filtered_org_files)
|
||||
|
||||
files_with_non_org_extensions = {org_file for org_file in all_org_files if not org_file.endswith(".org")}
|
||||
if any(files_with_non_org_extensions):
|
||||
print(f"[Warning] There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
logger.warn(f"There maybe non org-mode files in the input set: {files_with_non_org_extensions}")
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Processing files: {all_org_files}')
|
||||
logger.info(f'Processing files: {all_org_files}')
|
||||
|
||||
return all_org_files
|
||||
|
||||
|
@@ -64,69 +96,60 @@ def get_org_files(org_files=None, org_file_filter=None, verbose=0):
|
|||
def extract_org_entries(org_files):
|
||||
"Extract entries from specified Org files"
|
||||
entries = []
|
||||
entry_to_file_map = []
|
||||
for org_file in org_files:
|
||||
entries.extend(
|
||||
orgnode.makelist(
|
||||
str(org_file)))
|
||||
org_file_entries = orgnode.makelist(str(org_file))
|
||||
entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
|
||||
return entries
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
|
||||
def convert_org_entries_to_jsonl(entries, verbose=0) -> str:
|
||||
"Convert each Org-Mode entries to JSON and collate as JSONL"
|
||||
jsonl = ''
|
||||
def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]:
|
||||
"Convert Org-Mode entries into list of dictionary"
|
||||
entry_maps = []
|
||||
for entry in entries:
|
||||
entry_dict = dict()
|
||||
|
||||
if not entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
if not entry.Body() or re.sub(r'\n|\t|\r| ', '', entry.Body()) == "":
|
||||
continue
|
||||
|
||||
entry_dict["compiled"] = f'{entry.Heading()}.'
|
||||
if verbose > 2:
|
||||
print(f"Title: {entry.Heading()}")
|
||||
entry_dict["compiled"] = f'{entry.heading}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Title: {entry.heading}")
|
||||
|
||||
if entry.Tags():
|
||||
tags_str = " ".join(entry.Tags())
|
||||
if entry.tags:
|
||||
tags_str = " ".join(entry.tags)
|
||||
entry_dict["compiled"] += f'\t {tags_str}.'
|
||||
if verbose > 2:
|
||||
print(f"Tags: {tags_str}")
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Tags: {tags_str}")
|
||||
|
||||
if entry.Closed():
|
||||
entry_dict["compiled"] += f'\n Closed on {entry.Closed().strftime("%Y-%m-%d")}.'
|
||||
if verbose > 2:
|
||||
print(f'Closed: {entry.Closed().strftime("%Y-%m-%d")}')
|
||||
if entry.closed:
|
||||
entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}')
|
||||
|
||||
if entry.Scheduled():
|
||||
entry_dict["compiled"] += f'\n Scheduled for {entry.Scheduled().strftime("%Y-%m-%d")}.'
|
||||
if verbose > 2:
|
||||
print(f'Scheduled: {entry.Scheduled().strftime("%Y-%m-%d")}')
|
||||
if entry.scheduled:
|
||||
entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}')
|
||||
|
||||
if entry.Body():
|
||||
entry_dict["compiled"] += f'\n {entry.Body()}'
|
||||
if verbose > 2:
|
||||
print(f"Body: {entry.Body()}")
|
||||
if entry.hasBody:
|
||||
entry_dict["compiled"] += f'\n {entry.body}'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Body: {entry.body}")
|
||||
|
||||
if entry_dict:
|
||||
entry_dict["raw"] = f'{entry}'
|
||||
entry_dict["file"] = f'{entry_to_file_map[entry]}'
|
||||
|
||||
# Convert Dictionary to JSON and Append to JSONL string
|
||||
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
|
||||
entry_maps.append(entry_dict)
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Converted {len(entries)} to jsonl format")
|
||||
|
||||
return jsonl
|
||||
return entry_maps
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup Argument Parser
|
||||
parser = argparse.ArgumentParser(description="Map Org-Mode notes into (compressed) JSONL format")
|
||||
parser.add_argument('--output-file', '-o', type=pathlib.Path, required=True, help="Output file for (compressed) JSONL formatted notes. Expected file extensions: jsonl or jsonl.gz")
|
||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of org-mode files to process")
|
||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for org-mode files to process")
|
||||
parser.add_argument('--verbose', '-v', action='count', default=0, help="Show verbose conversion logs, Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Map notes in Org-Mode files to (compressed) JSONL formatted file
|
||||
org_to_jsonl(args.input_files, args.input_filter, args.output_file, args.verbose)
|
||||
def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str:
|
||||
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
||||
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
|
||||
|
|
|
src/processor/org_mode/orgnode.py

@@ -33,16 +33,21 @@ headline and associated text from an org-mode file, and routines for
|
|||
constructing data structures of these classes.
|
||||
"""
|
||||
|
||||
import re, sys
|
||||
import re
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from os.path import relpath
|
||||
|
||||
indent_regex = re.compile(r'^\s*')
|
||||
indent_regex = re.compile(r'^ *')
|
||||
|
||||
def normalize_filename(filename):
|
||||
file_relative_to_home = f'~/{relpath(filename, start=Path.home())}'
|
||||
escaped_filename = f'{file_relative_to_home}'.replace("[","\[").replace("]","\]")
|
||||
"Normalize and escape filename for rendering"
|
||||
if not Path(filename).is_absolute():
|
||||
# Normalize relative filename to be relative to current directory
|
||||
normalized_filename = f'~/{relpath(filename, start=Path.home())}'
|
||||
else:
|
||||
normalized_filename = filename
|
||||
escaped_filename = f'{normalized_filename}'.replace("[","\[").replace("]","\]")
|
||||
return escaped_filename
|
||||
|
||||
def makelist(filename):
|
||||
|
@@ -52,65 +57,71 @@ def makelist(filename):
|
|||
"""
|
||||
ctr = 0
|
||||
|
||||
try:
|
||||
f = open(filename, 'r')
|
||||
except IOError:
|
||||
print(f"Unable to open file {filename}")
|
||||
print("Program terminating.")
|
||||
sys.exit(1)
|
||||
|
||||
todos = { "TODO": "", "WAITING": "", "ACTIVE": "",
|
||||
"DONE": "", "CANCELLED": "", "FAILED": ""} # populated from #+SEQ_TODO line
|
||||
level = 0
|
||||
level = ""
|
||||
heading = ""
|
||||
bodytext = ""
|
||||
tags = set() # set of all tags in headline
|
||||
tags = list() # set of all tags in headline
|
||||
closed_date = ''
|
||||
sched_date = ''
|
||||
deadline_date = ''
|
||||
logbook = list()
|
||||
nodelist = []
|
||||
propdict = dict()
|
||||
nodelist: list[Orgnode] = list()
|
||||
property_map = dict()
|
||||
in_properties_drawer = False
|
||||
in_logbook_drawer = False
|
||||
file_title = f'{filename}'
|
||||
|
||||
for line in f:
|
||||
ctr += 1
|
||||
hdng = re.search(r'^(\*+)\s(.*?)\s*$', line)
|
||||
if hdng: # we are processing a heading line
|
||||
heading_search = re.search(r'^(\*+)\s(.*?)\s*$', line)
|
||||
if heading_search: # we are processing a heading line
|
||||
if heading: # if we have are on second heading, append first heading to headings list
|
||||
thisNode = Orgnode(level, heading, bodytext, tags)
|
||||
if closed_date:
|
||||
thisNode.setClosed(closed_date)
|
||||
thisNode.closed = closed_date
|
||||
closed_date = ''
|
||||
if sched_date:
|
||||
thisNode.setScheduled(sched_date)
|
||||
thisNode.scheduled = sched_date
|
||||
sched_date = ""
|
||||
if deadline_date:
|
||||
thisNode.setDeadline(deadline_date)
|
||||
thisNode.deadline = deadline_date
|
||||
deadline_date = ''
|
||||
if logbook:
|
||||
thisNode.setLogbook(logbook)
|
||||
thisNode.logbook = logbook
|
||||
logbook = list()
|
||||
thisNode.setProperties(propdict)
|
||||
thisNode.properties = property_map
|
||||
nodelist.append( thisNode )
|
||||
propdict = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
|
||||
level = hdng.group(1)
|
||||
heading = hdng.group(2)
|
||||
property_map = {'LINE': f'file:{normalize_filename(filename)}::{ctr}'}
|
||||
level = heading_search.group(1)
|
||||
heading = heading_search.group(2)
|
||||
bodytext = ""
|
||||
tags = set() # set of all tags in headline
|
||||
tagsrch = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
|
||||
if tagsrch:
|
||||
heading = tagsrch.group(1)
|
||||
parsedtags = tagsrch.group(2)
|
||||
tags = list() # set of all tags in headline
|
||||
tag_search = re.search(r'(.*?)\s*:([a-zA-Z0-9].*?):$',heading)
|
||||
if tag_search:
|
||||
heading = tag_search.group(1)
|
||||
parsedtags = tag_search.group(2)
|
||||
if parsedtags:
|
||||
for parsedtag in parsedtags.split(':'):
|
||||
if parsedtag != '': tags.add(parsedtag)
|
||||
if parsedtag != '': tags.append(parsedtag)
|
||||
else: # we are processing a non-heading line
|
||||
if line[:10] == '#+SEQ_TODO':
|
||||
kwlist = re.findall(r'([A-Z]+)\(', line)
|
||||
for kw in kwlist: todos[kw] = ""
|
||||
|
||||
# Set file title to TITLE property, if it exists
|
||||
title_search = re.search(r'^#\+TITLE:\s*(.*)$', line)
|
||||
if title_search and title_search.group(1).strip() != '':
|
||||
title_text = title_search.group(1).strip()
|
||||
if file_title == f'{filename}':
|
||||
file_title = title_text
|
||||
else:
|
||||
file_title += f' {title_text}'
|
||||
continue
|
||||
|
||||
# Ignore Properties Drawers Completely
|
||||
if re.search(':PROPERTIES:', line):
|
||||
in_properties_drawer=True
|
||||
|
@@ -137,13 +148,13 @@ def makelist(filename):
|
|||
logbook += [(clocked_in, clocked_out)]
|
||||
line = ""
|
||||
|
||||
prop_srch = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
|
||||
if prop_srch:
|
||||
property_search = re.search(r'^\s*:([a-zA-Z0-9]+):\s*(.*?)\s*$', line)
|
||||
if property_search:
|
||||
# Set ID property to an id based org-mode link to the entry
|
||||
if prop_srch.group(1) == 'ID':
|
||||
propdict['ID'] = f'id:{prop_srch.group(2)}'
|
||||
if property_search.group(1) == 'ID':
|
||||
property_map['ID'] = f'id:{property_search.group(2)}'
|
||||
else:
|
||||
propdict[prop_srch.group(1)] = prop_srch.group(2)
|
||||
property_map[property_search.group(1)] = property_search.group(2)
|
||||
continue
|
||||
|
||||
cd_re = re.search(r'CLOSED:\s*\[([0-9]{4})-([0-9]{2})-([0-9]{2})', line)
|
||||
|
@@ -167,36 +178,39 @@ def makelist(filename):
|
|||
bodytext = bodytext + line
|
||||
|
||||
# write out last node
|
||||
thisNode = Orgnode(level, heading, bodytext, tags)
|
||||
thisNode.setProperties(propdict)
|
||||
thisNode = Orgnode(level, heading or file_title, bodytext, tags)
|
||||
thisNode.properties = property_map
|
||||
if sched_date:
|
||||
thisNode.setScheduled(sched_date)
|
||||
thisNode.scheduled = sched_date
|
||||
if deadline_date:
|
||||
thisNode.setDeadline(deadline_date)
|
||||
thisNode.deadline = deadline_date
|
||||
if closed_date:
|
||||
thisNode.setClosed(closed_date)
|
||||
thisNode.closed = closed_date
|
||||
if logbook:
|
||||
thisNode.setLogbook(logbook)
|
||||
thisNode.logbook = logbook
|
||||
nodelist.append( thisNode )
|
||||
|
||||
# using the list of TODO keywords found in the file
|
||||
# process the headings searching for TODO keywords
|
||||
for n in nodelist:
|
||||
h = n.Heading()
|
||||
todoSrch = re.search(r'([A-Z]+)\s(.*?)$', h)
|
||||
if todoSrch:
|
||||
if todoSrch.group(1) in todos:
|
||||
n.setHeading( todoSrch.group(2) )
|
||||
n.setTodo ( todoSrch.group(1) )
|
||||
todo_search = re.search(r'([A-Z]+)\s(.*?)$', n.heading)
|
||||
if todo_search:
|
||||
if todo_search.group(1) in todos:
|
||||
n.heading = todo_search.group(2)
|
||||
n.todo = todo_search.group(1)
|
||||
|
||||
# extract, set priority from heading, update heading if necessary
|
||||
prtysrch = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.Heading())
|
||||
if prtysrch:
|
||||
n.setPriority(prtysrch.group(1))
|
||||
n.setHeading(prtysrch.group(2))
|
||||
priority_search = re.search(r'^\[\#(A|B|C)\] (.*?)$', n.heading)
|
||||
if priority_search:
|
||||
n.priority = priority_search.group(1)
|
||||
n.heading = priority_search.group(2)
|
||||
|
||||
# Set SOURCE property to a file+heading based org-mode link to the entry
|
||||
escaped_heading = n.Heading().replace("[","\\[").replace("]","\\]")
|
||||
if n.level == 0:
|
||||
n.properties['LINE'] = f'file:{normalize_filename(filename)}::0'
|
||||
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}]]'
|
||||
else:
|
||||
escaped_heading = n.heading.replace("[","\\[").replace("]","\\]")
|
||||
n.properties['SOURCE'] = f'[[file:{normalize_filename(filename)}::*{escaped_heading}]]'
|
||||
|
||||
return nodelist
|
||||
|
@@ -214,199 +228,234 @@ class Orgnode(object):
|
|||
first tag. The makelist routine postprocesses the list to
|
||||
identify TODO tags and updates headline and todo fields.
|
||||
"""
|
||||
self.level = len(level)
|
||||
self.headline = headline
|
||||
self.body = body
|
||||
self.tags = set(tags) # All tags in the headline
|
||||
self.todo = ""
|
||||
self.prty = "" # empty of A, B or C
|
||||
self.scheduled = "" # Scheduled date
|
||||
self.deadline = "" # Deadline date
|
||||
self.closed = "" # Closed date
|
||||
self.properties = dict()
|
||||
self.logbook = list() # List of clock-in, clock-out tuples representing logbook entries
|
||||
self._level = len(level)
|
||||
self._heading = headline
|
||||
self._body = body
|
||||
self._tags = tags # All tags in the headline
|
||||
self._todo = ""
|
||||
self._priority = "" # empty of A, B or C
|
||||
self._scheduled = "" # Scheduled date
|
||||
self._deadline = "" # Deadline date
|
||||
self._closed = "" # Closed date
|
||||
self._properties = dict()
|
||||
self._logbook = list() # List of clock-in, clock-out tuples representing logbook entries
|
||||
|
||||
# Look for priority in headline and transfer to prty field
|
||||
|
||||
def Heading(self):
|
||||
@property
|
||||
def heading(self):
|
||||
"""
|
||||
Return the Heading text of the node without the TODO tag
|
||||
"""
|
||||
return self.headline
|
||||
return self._heading
|
||||
|
||||
def setHeading(self, newhdng):
|
||||
@heading.setter
|
||||
def heading(self, newhdng):
|
||||
"""
|
||||
Change the heading to the supplied string
|
||||
"""
|
||||
self.headline = newhdng
|
||||
self._heading = newhdng
|
||||
|
||||
def Body(self):
|
||||
@property
|
||||
def body(self):
|
||||
"""
|
||||
Returns all lines of text of the body of this node except the
|
||||
Property Drawer
|
||||
"""
|
||||
return self.body
|
||||
return self._body
|
||||
|
||||
def Level(self):
|
||||
@property
|
||||
def hasBody(self):
|
||||
"""
|
||||
Returns True if node has non empty body, else False
|
||||
"""
|
||||
return self._body and re.sub(r'\n|\t|\r| ', '', self._body) != ''
|
||||
|
||||
@property
|
||||
def level(self):
|
||||
"""
|
||||
Returns an integer corresponding to the level of the node.
|
||||
Top level (one asterisk) has a level of 1.
|
||||
"""
|
||||
return self.level
|
||||
return self._level
|
||||
|
||||
def Priority(self):
|
||||
@property
|
||||
def priority(self):
|
||||
"""
|
||||
Returns the priority of this headline: 'A', 'B', 'C' or empty
|
||||
string if priority has not been set.
|
||||
"""
|
||||
return self.prty
|
||||
return self._priority
|
||||
|
||||
def setPriority(self, newprty):
|
||||
@priority.setter
|
||||
def priority(self, new_priority):
|
||||
"""
|
||||
Change the value of the priority of this headline.
|
||||
Values values are '', 'A', 'B', 'C'
|
||||
"""
|
||||
self.prty = newprty
|
||||
self._priority = new_priority
|
||||
|
||||
def Tags(self):
|
||||
@property
|
||||
def tags(self):
|
||||
"""
|
||||
Returns the set of all tags
|
||||
For example, :HOME:COMPUTER: would return {'HOME', 'COMPUTER'}
|
||||
Returns the list of all tags
|
||||
For example, :HOME:COMPUTER: would return ['HOME', 'COMPUTER']
|
||||
"""
|
||||
return self.tags
|
||||
return self._tags
|
||||
|
||||
def hasTag(self, srch):
|
||||
@tags.setter
|
||||
def tags(self, newtags):
|
||||
"""
|
||||
Store all the tags found in the headline.
|
||||
"""
|
||||
self._tags = newtags
|
||||
|
||||
def hasTag(self, tag):
|
||||
"""
|
||||
Returns True if the supplied tag is present in this headline
|
||||
For example, hasTag('COMPUTER') on a headline containing
|
||||
:HOME:COMPUTER: would return True.
|
||||
"""
|
||||
return srch in self.tags
|
||||
return tag in self._tags
|
||||
|
||||
def setTags(self, newtags):
|
||||
"""
|
||||
Store all the tags found in the headline.
|
||||
"""
|
||||
self.tags = set(newtags)
|
||||
|
||||
def Todo(self):
|
||||
@property
|
||||
def todo(self):
|
||||
"""
|
||||
Return the value of the TODO tag
|
||||
"""
|
||||
return self.todo
|
||||
return self._todo
|
||||
|
||||
def setTodo(self, value):
|
||||
@todo.setter
|
||||
def todo(self, new_todo):
|
||||
"""
|
||||
Set the value of the TODO tag to the supplied string
|
||||
"""
|
||||
self.todo = value
|
||||
self._todo = new_todo
|
||||
|
||||
def setProperties(self, dictval):
|
||||
@property
|
||||
def properties(self):
|
||||
"""
|
||||
Return the dictionary of properties
|
||||
"""
|
||||
return self._properties
|
||||
|
||||
@properties.setter
|
||||
def properties(self, new_properties):
|
||||
"""
|
||||
Sets all properties using the supplied dictionary of
|
||||
name/value pairs
|
||||
"""
|
||||
self.properties = dictval
|
||||
self._properties = new_properties
|
||||
|
||||
def Property(self, keyval):
|
||||
def Property(self, property_key):
|
||||
"""
|
||||
Returns the value of the requested property or null if the
|
||||
property does not exist.
|
||||
"""
|
||||
return self.properties.get(keyval, "")
|
||||
return self._properties.get(property_key, "")
|
||||
|
||||
def setScheduled(self, dateval):
|
||||
@property
|
||||
def scheduled(self):
|
||||
"""
|
||||
Set the scheduled date using the supplied date object
|
||||
Return the scheduled date
|
||||
"""
|
||||
self.scheduled = dateval
|
||||
return self._scheduled
|
||||
|
||||
def Scheduled(self):
|
||||
@scheduled.setter
|
||||
def scheduled(self, new_scheduled):
|
||||
"""
|
||||
Return the scheduled date object or null if nonexistent
|
||||
Set the scheduled date to the new scheduled date
|
||||
"""
|
||||
return self.scheduled
|
||||
self._scheduled = new_scheduled
|
||||
|
||||
def setDeadline(self, dateval):
|
||||
@property
|
||||
def deadline(self):
|
||||
"""
|
||||
Set the deadline (due) date using the supplied date object
|
||||
Return the deadline date
|
||||
"""
|
||||
self.deadline = dateval
|
||||
return self._deadline
|
||||
|
||||
def Deadline(self):
|
||||
@deadline.setter
|
||||
def deadline(self, new_deadline):
|
||||
"""
|
||||
Return the deadline date object or null if nonexistent
|
||||
Set the deadline (due) date to the new deadline date
|
||||
"""
|
||||
return self.deadline
|
||||
self._deadline = new_deadline
|
||||
|
||||
def setClosed(self, dateval):
|
||||
@property
|
||||
def closed(self):
|
||||
"""
|
||||
Set the closed date using the supplied date object
|
||||
Return the closed date
|
||||
"""
|
||||
self.closed = dateval
|
||||
return self._closed
|
||||
|
||||
def Closed(self):
|
||||
@closed.setter
|
||||
def closed(self, new_closed):
|
||||
"""
|
||||
Return the closed date object or null if nonexistent
|
||||
Set the closed date to the new closed date
|
||||
"""
|
||||
return self.closed
|
||||
self._closed = new_closed
|
||||
|
||||
def setLogbook(self, logbook):
|
||||
"""
|
||||
Set the logbook with list of clocked-in, clocked-out tuples for the entry
|
||||
"""
|
||||
self.logbook = logbook
|
||||
|
||||
def Logbook(self):
|
||||
@property
|
||||
def logbook(self):
|
||||
"""
|
||||
Return the logbook with all clocked-in, clocked-out date object pairs or empty list if nonexistent
|
||||
"""
|
||||
return self.logbook
|
||||
return self._logbook
|
||||
|
||||
@logbook.setter
|
||||
def logbook(self, new_logbook):
|
||||
"""
|
||||
Set the logbook with list of clocked-in, clocked-out tuples for the entry
|
||||
"""
|
||||
self._logbook = new_logbook
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Print the level, heading text and tag of a node and the body
|
||||
text as used to construct the node.
|
||||
"""
|
||||
# This method is not completed yet.
|
||||
# Output heading line
|
||||
n = ''
|
||||
for _ in range(0, self.level):
|
||||
for _ in range(0, self._level):
|
||||
n = n + '*'
|
||||
n = n + ' '
|
||||
if self.todo:
|
||||
n = n + self.todo + ' '
|
||||
if self.prty:
|
||||
n = n + '[#' + self.prty + '] '
|
||||
n = n + self.headline
|
||||
if self._todo:
|
||||
n = n + self._todo + ' '
|
||||
if self._priority:
|
||||
n = n + '[#' + self._priority + '] '
|
||||
n = n + self._heading
|
||||
n = "%-60s " % n # hack - tags will start in column 62
|
||||
closecolon = ''
|
||||
for t in self.tags:
|
||||
for t in self._tags:
|
||||
n = n + ':' + t
|
||||
closecolon = ':'
|
||||
n = n + closecolon
|
||||
n = n + "\n"
|
||||
|
||||
# Get body indentation from first line of body
|
||||
indent = indent_regex.match(self.body).group()
|
||||
indent = indent_regex.match(self._body).group()
|
||||
|
||||
# Output Closed Date, Scheduled Date, Deadline Date
|
||||
if self.closed or self.scheduled or self.deadline:
|
||||
if self._closed or self._scheduled or self._deadline:
|
||||
n = n + indent
|
||||
if self.closed:
|
||||
n = n + f'CLOSED: [{self.closed.strftime("%Y-%m-%d %a")}] '
|
||||
if self.scheduled:
|
||||
n = n + f'SCHEDULED: <{self.scheduled.strftime("%Y-%m-%d %a")}> '
|
||||
if self.deadline:
|
||||
n = n + f'DEADLINE: <{self.deadline.strftime("%Y-%m-%d %a")}> '
|
||||
if self.closed or self.scheduled or self.deadline:
|
||||
if self._closed:
|
||||
n = n + f'CLOSED: [{self._closed.strftime("%Y-%m-%d %a")}] '
|
||||
if self._scheduled:
|
||||
n = n + f'SCHEDULED: <{self._scheduled.strftime("%Y-%m-%d %a")}> '
|
||||
if self._deadline:
|
||||
n = n + f'DEADLINE: <{self._deadline.strftime("%Y-%m-%d %a")}> '
|
||||
if self._closed or self._scheduled or self._deadline:
|
||||
n = n + '\n'
|
||||
|
||||
# Output Property Drawer
|
||||
n = n + indent + ":PROPERTIES:\n"
|
||||
for key, value in self.properties.items():
|
||||
for key, value in self._properties.items():
|
||||
n = n + indent + f":{key}: {value}\n"
|
||||
n = n + indent + ":END:\n"
|
||||
|
||||
n = n + self.body
|
||||
# Output Body
|
||||
if self.hasBody:
|
||||
n = n + self._body
|
||||
|
||||
return n
|
||||
|
|
|
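For orientation, the refactor above replaces the old accessor methods (Heading(), setHeading(), Priority(), …) with Python properties. A small hypothetical usage sketch follows; the import path and positional constructor arguments are assumptions inferred from the assignments shown above, not taken verbatim from the repository:

# Hypothetical sketch of the new property-style Orgnode API
from src.processor.org_mode.orgnode import Orgnode  # import path assumed

node = Orgnode('**', 'Buy milk', 'From the corner store\n', ['HOME'])
node.priority = 'A'                  # replaces the old setPriority()
node.tags = node.tags + ['ERRAND']   # tags is now a plain list with a setter
if node.hasTag('HOME') and node.hasBody:
    print(node.level, node.heading, node.priority, node.tags)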
@ -2,8 +2,8 @@
|
|||
import yaml
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
|
@ -15,16 +15,17 @@ from fastapi.templating import Jinja2Templates
|
|||
from src.configure import configure_search
|
||||
from src.search_type import image_search, text_search
|
||||
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
||||
from src.search_filter.explicit_filter import ExplicitFilter
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
from src.utils.rawconfig import FullConfig
|
||||
from src.utils.config import SearchType
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict
|
||||
from src.utils.helpers import LRU, get_absolute_path, get_from_dict
|
||||
from src.utils import state, constants
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory=constants.web_directory)
|
||||
logger = logging.getLogger(__name__)
|
||||
query_cache = LRU()
|
||||
|
||||
|
||||
@router.get("/", response_class=FileResponse)
|
||||
def index():
|
||||
|
@ -47,22 +48,27 @@ async def config_data(updated_config: FullConfig):
|
|||
return state.config
|
||||
|
||||
@router.get('/search')
|
||||
@lru_cache(maxsize=100)
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||
if q is None or q == '':
|
||||
print(f'No query param (q) passed in API call to initiate search')
|
||||
logger.info(f'No query param (q) passed in API call to initiate search')
|
||||
return {}
|
||||
|
||||
# initialize variables
|
||||
user_query = q
|
||||
user_query = q.strip()
|
||||
results_count = n
|
||||
results = {}
|
||||
query_start, query_end, collate_start, collate_end = None, None, None, None
|
||||
|
||||
# return cached results, if available
|
||||
query_cache_key = f'{user_query}-{n}-{t}-{r}'
|
||||
if query_cache_key in state.query_cache:
|
||||
logger.info(f'Return response from query cache')
|
||||
return state.query_cache[query_cache_key]
|
||||
|
||||
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -73,7 +79,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||
if (t == SearchType.Music or t == None) and state.model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -84,7 +90,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -95,7 +101,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
|
||||
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -131,11 +137,13 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
|
|||
count=results_count)
|
||||
collate_end = time.time()
|
||||
|
||||
if state.verbose > 1:
|
||||
# Cache results
|
||||
state.query_cache[query_cache_key] = results
|
||||
|
||||
if query_start and query_end:
|
||||
print(f"Query took {query_end - query_start:.3f} seconds")
|
||||
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
|
||||
if collate_start and collate_end:
|
||||
print(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
||||
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
|
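As a quick illustration of the updated endpoint and its server-side query cache, a client call might look like the following sketch. The host and port are assumptions taken from the CLI defaults further down in this diff, and the query string just exercises the word filter syntax:

# Hypothetical client call against the /search endpoint
import requests

response = requests.get(
    "http://127.0.0.1:8000/search",  # default host:port assumed from cli.py below
    params={"q": 'rent +"april"', "n": 5, "r": True},
)
results = response.json()
# Repeating the identical call should now be answered from the query cache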
16
src/search_filter/base_filter.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseFilter(ABC):
|
||||
@abstractmethod
|
||||
def load(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_filter(self, raw_query:str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]:
|
||||
pass
|
|
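Concrete filters (the date, file and word filters below) implement this interface. A minimal, hypothetical pass-through filter for illustration:

# Hypothetical no-op filter implementing the BaseFilter interface
from src.search_filter.base_filter import BaseFilter


class NoOpFilter(BaseFilter):
    def load(self, entries, *args, **kwargs):
        # Nothing to index for this filter
        pass

    def can_filter(self, raw_query: str) -> bool:
        # Never claims a query, so text search always skips it
        return False

    def apply(self, query: str, raw_entries: list[str]) -> tuple[str, set[int]]:
        # Keep the query as-is and include every entry
        return query, set(range(len(raw_entries)))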
@ -1,15 +1,24 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta, datetime
|
||||
from dateutil.relativedelta import relativedelta, MO
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from math import inf
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
import dateparser as dtparse
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
|
||||
class DateFilter:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DateFilter(BaseFilter):
|
||||
# Date Range Filter Regexes
|
||||
# Example filter queries:
|
||||
# - dt>="yesterday" dt<"tomorrow"
|
||||
|
@ -17,46 +26,73 @@ class DateFilter:
|
|||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
def __init__(self, entry_key='raw'):
|
||||
self.entry_key = entry_key
|
||||
self.date_to_entry_ids = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
|
||||
def filter(self, query, entries, embeddings, entry_key='raw'):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
query_daterange = self.extract_date_range(query)
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
return query, entries, embeddings
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
||||
|
||||
# find entries containing any dates that fall within the date range specified in query
|
||||
entries_to_include = set()
|
||||
def load(self, entries, **_):
|
||||
start = time.time()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
self.date_to_entry_ids[date_in_entry].add(id)
|
||||
end = time.time()
|
||||
logger.debug(f"Created date filter index: {end - start} seconds")
|
||||
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
|
||||
def apply(self, query, raw_entries):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
start = time.time()
|
||||
query_daterange = self.extract_date_range(query)
|
||||
end = time.time()
|
||||
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
return query, set(range(len(raw_entries)))
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
||||
|
||||
# return results from cache if exists
|
||||
cache_key = tuple(query_daterange)
|
||||
if cache_key in self.cache:
|
||||
logger.info(f"Return date filter results from cache")
|
||||
entries_to_include = self.cache[cache_key]
|
||||
return query, entries_to_include
|
||||
|
||||
if not self.date_to_entry_ids:
|
||||
self.load(raw_entries)
|
||||
|
||||
# find entries containing any dates that fall within the date range specified in query
|
||||
start = time.time()
|
||||
entries_to_include = set()
|
||||
for date_in_entry in self.date_to_entry_ids.keys():
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include.add(id)
|
||||
break
|
||||
entries_to_include |= self.date_to_entry_ids[date_in_entry]
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
entries_to_exclude = set(range(len(entries))) - entries_to_include
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
# cache results
|
||||
self.cache[cache_key] = entries_to_include
|
||||
|
||||
return query, entries, embeddings
|
||||
return query, entries_to_include
|
||||
|
||||
|
||||
def extract_date_range(self, query):
|
||||
|
|
|
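A hypothetical usage sketch of the reworked date filter, using the dt operators documented above; the entries and the expected outcome are illustrative only:

# Hypothetical usage of DateFilter on raw entries
from src.search_filter.date_filter import DateFilter

entries = [
    {'raw': 'Paid rent on 2021-04-01'},
    {'raw': 'Paid rent on 2021-05-01'},
]
date_filter = DateFilter(entry_key='raw')
date_filter.load(entries)

raw_query = 'rent dt>="2021-04-15" dt<"2021-06-01"'
if date_filter.can_filter(raw_query):
    query, included = date_filter.apply(raw_query, entries)
    # expected: query == 'rent', included == {1} (only the 2021-05-01 entry is in range)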
@ -1,57 +0,0 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
||||
|
||||
class ExplicitFilter:
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains explicit filters"
|
||||
# Extract explicit query portion with required, blocked words to filter from natural query
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
|
||||
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, entries, embeddings
|
||||
|
||||
# convert each entry to a set of words
|
||||
# split on fullstop, comma, colon, tab, newline or any brackets
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||
entries_by_word_set = [set(word.lower()
|
||||
for word
|
||||
in re.split(entry_splitter, entry[entry_key])
|
||||
if word != "")
|
||||
for entry in entries]
|
||||
|
||||
# track id of entries to exclude
|
||||
entries_to_exclude = set()
|
||||
|
||||
# mark entries that do not contain all required_words for exclusion
|
||||
if len(required_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if not required_words.issubset(words_in_entry):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
if len(blocked_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if words_in_entry.intersection(blocked_words):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
|
||||
return query, entries, embeddings
|
79
src/search_filter/file_filter.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
import fnmatch
|
||||
import time
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileFilter(BaseFilter):
|
||||
file_filter_regex = r'file:"(.+?)" ?'
|
||||
|
||||
def __init__(self, entry_key='file'):
|
||||
self.entry_key = entry_key
|
||||
self.file_to_entry_map = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
def load(self, entries, *args, **kwargs):
|
||||
start = time.time()
|
||||
for id, entry in enumerate(entries):
|
||||
self.file_to_entry_map[entry[self.entry_key]].add(id)
|
||||
end = time.time()
|
||||
logger.debug(f"Created file filter index: {end - start} seconds")
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
return re.search(self.file_filter_regex, raw_query) is not None
|
||||
|
||||
def apply(self, raw_query, raw_entries):
|
||||
# Extract file filters from raw query
|
||||
start = time.time()
|
||||
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
|
||||
if not raw_files_to_search:
|
||||
return raw_query, set(range(len(raw_entries)))
|
||||
|
||||
# Convert simple file filters with no path separator into glob patterns
|
||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||
files_to_search = []
|
||||
for file in sorted(raw_files_to_search):
|
||||
if '/' not in file and '\\' not in file and '*' not in file:
|
||||
files_to_search += [f'*{file}']
|
||||
else:
|
||||
files_to_search += [file]
|
||||
end = time.time()
|
||||
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
||||
|
||||
# Return item from cache if exists
|
||||
query = re.sub(self.file_filter_regex, '', raw_query).strip()
|
||||
cache_key = tuple(files_to_search)
|
||||
if cache_key in self.cache:
|
||||
logger.info(f"Return file filter results from cache")
|
||||
included_entry_indices = self.cache[cache_key]
|
||||
return query, included_entry_indices
|
||||
|
||||
if not self.file_to_entry_map:
|
||||
self.load(raw_entries, regenerate=False)
|
||||
|
||||
# Mark entries whose source file matches any of the searched file patterns for inclusion
|
||||
start = time.time()
|
||||
|
||||
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
|
||||
for entry_file in self.file_to_entry_map.keys()
|
||||
for search_file in files_to_search
|
||||
if fnmatch.fnmatch(entry_file, search_file)], set())
|
||||
if not included_entry_indices:
|
||||
return query, {}
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
|
||||
# Cache results
|
||||
self.cache[cache_key] = included_entry_indices
|
||||
|
||||
return query, included_entry_indices
|
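A hypothetical usage sketch of the new file filter; the entries are illustrative and the file:"…" syntax is taken from the regex above:

# Hypothetical usage of FileFilter on raw entries
from src.search_filter.file_filter import FileFilter

entries = [
    {'file': '/notes/journal.org', 'raw': '* Went hiking'},
    {'file': '/notes/tasks.org', 'raw': '* TODO Buy milk'},
]
file_filter = FileFilter(entry_key='file')
file_filter.load(entries)

raw_query = 'hiking file:"journal.org"'
if file_filter.can_filter(raw_query):
    query, included = file_filter.apply(raw_query, entries)
    # query == 'hiking', included == {0} since only journal.org matches *journal.org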
96
src/search_filter/word_filter.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
# Standard Packages
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
# Internal Packages
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
from src.utils.helpers import LRU
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordFilter(BaseFilter):
|
||||
# Filter Regex
|
||||
required_regex = r'\+"([a-zA-Z0-9_-]+)" ?'
|
||||
blocked_regex = r'\-"([a-zA-Z0-9_-]+)" ?'
|
||||
|
||||
def __init__(self, entry_key='raw'):
|
||||
self.entry_key = entry_key
|
||||
self.word_to_entry_index = defaultdict(set)
|
||||
self.cache = LRU()
|
||||
|
||||
|
||||
def load(self, entries, regenerate=False):
|
||||
start = time.time()
|
||||
self.cache = {} # Clear cache on filter (re-)load
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||
# Create map of words to entries they exist in
|
||||
for entry_index, entry in enumerate(entries):
|
||||
for word in re.split(entry_splitter, entry[self.entry_key].lower()):
|
||||
if word == '':
|
||||
continue
|
||||
self.word_to_entry_index[word].add(entry_index)
|
||||
end = time.time()
|
||||
logger.debug(f"Created word filter index: {end - start} seconds")
|
||||
|
||||
return self.word_to_entry_index
|
||||
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains word filters"
|
||||
required_words = re.findall(self.required_regex, raw_query)
|
||||
blocked_words = re.findall(self.blocked_regex, raw_query)
|
||||
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
|
||||
def apply(self, raw_query, raw_entries):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from required, blocked words filters
|
||||
start = time.time()
|
||||
|
||||
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
|
||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
|
||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, set(range(len(raw_entries)))
|
||||
|
||||
# Return item from cache if exists
|
||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
||||
if cache_key in self.cache:
|
||||
logger.info(f"Return word filter results from cache")
|
||||
included_entry_indices = self.cache[cache_key]
|
||||
return query, included_entry_indices
|
||||
|
||||
if not self.word_to_entry_index:
|
||||
self.load(raw_entries, regenerate=False)
|
||||
|
||||
start = time.time()
|
||||
|
||||
# mark entries that contain all required_words for inclusion
|
||||
entries_with_all_required_words = set(range(len(raw_entries)))
|
||||
if len(required_words) > 0:
|
||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
entries_with_any_blocked_words = set()
|
||||
if len(blocked_words) > 0:
|
||||
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
|
||||
|
||||
end = time.time()
|
||||
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
|
||||
|
||||
# get entries satisfying inclusion and exclusion filters
|
||||
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
|
||||
|
||||
# Cache results
|
||||
self.cache[cache_key] = included_entry_indices
|
||||
|
||||
return query, included_entry_indices
|
|
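A matching hypothetical sketch for the word filter above, showing the +"word" and -"word" syntax; entries are toy data:

# Hypothetical usage of WordFilter on raw entries
from src.search_filter.word_filter import WordFilter

entries = [
    {'raw': 'Bought apples and oranges'},
    {'raw': 'Sold the old laptop'},
]
word_filter = WordFilter(entry_key='raw')
word_filter.load(entries)

raw_query = 'groceries +"apples" -"laptop"'
if word_filter.can_filter(raw_query):
    query, included = word_filter.apply(raw_query, entries)
    # query == 'groceries', included == {0}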
@ -1,9 +1,10 @@
|
|||
# Standard Packages
|
||||
import argparse
|
||||
import glob
|
||||
import pathlib
|
||||
import copy
|
||||
import shutil
|
||||
import time
|
||||
import logging
|
||||
|
||||
# External Packages
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
|
@ -18,6 +19,10 @@ from src.utils.config import ImageSearchModel
|
|||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig
|
||||
|
||||
|
||||
# Create Logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_model(search_config: ImageSearchConfig):
|
||||
# Initialize Model
|
||||
torch.set_num_threads(4)
|
||||
|
@ -37,41 +42,43 @@ def initialize_model(search_config: ImageSearchConfig):
|
|||
return encoder
|
||||
|
||||
|
||||
def extract_entries(image_directories, verbose=0):
|
||||
def extract_entries(image_directories):
|
||||
image_names = []
|
||||
for image_directory in image_directories:
|
||||
image_directory = resolve_absolute_path(image_directory, strict=True)
|
||||
image_names.extend(list(image_directory.glob('*.jpg')))
|
||||
image_names.extend(list(image_directory.glob('*.jpeg')))
|
||||
|
||||
if verbose > 0:
|
||||
if logger.level >= logging.INFO:
|
||||
image_directory_names = ', '.join([str(image_directory) for image_directory in image_directories])
|
||||
print(f'Found {len(image_names)} images in {image_directory_names}')
|
||||
logger.info(f'Found {len(image_names)} images in {image_directory_names}')
|
||||
return sorted(image_names)
|
||||
|
||||
|
||||
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
|
||||
def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
|
||||
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose)
|
||||
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
|
||||
image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate)
|
||||
image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate)
|
||||
|
||||
return image_embeddings, image_metadata_embeddings
|
||||
|
||||
|
||||
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
||||
image_embeddings = None
|
||||
|
||||
def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False):
|
||||
# Load pre-computed image embeddings from file if exists
|
||||
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
||||
image_embeddings = torch.load(embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Loaded pre-computed embeddings from {embeddings_file}")
|
||||
logger.info(f"Loaded {len(image_embeddings)} image embeddings from {embeddings_file}")
|
||||
# Else compute the image embeddings from scratch, which can take a while
|
||||
elif image_embeddings is None:
|
||||
else:
|
||||
image_embeddings = []
|
||||
for index in trange(0, len(image_names), batch_size):
|
||||
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
||||
images = []
|
||||
for image_name in image_names[index:index+batch_size]:
|
||||
image = Image.open(image_name)
|
||||
# Resize images to max width of 640px for faster processing
|
||||
image.thumbnail((640, image.height))
|
||||
images += [image]
|
||||
image_embeddings += encoder.encode(
|
||||
images,
|
||||
convert_to_tensor=True,
|
||||
|
@ -82,8 +89,7 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
|
|||
|
||||
# Save computed image embeddings to file
|
||||
torch.save(image_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Saved computed embeddings to {embeddings_file}")
|
||||
logger.info(f"Saved computed embeddings to {embeddings_file}")
|
||||
|
||||
return image_embeddings
|
||||
|
||||
|
@ -94,8 +100,7 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
|||
# Load pre-computed image metadata embedding file if exists
|
||||
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
|
||||
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
|
||||
if verbose > 0:
|
||||
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
||||
logger.info(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
||||
|
||||
# Else compute the image metadata embeddings from scratch, which can take a while
|
||||
if use_xmp_metadata and image_metadata_embeddings is None:
|
||||
|
@ -108,16 +113,15 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
|||
convert_to_tensor=True,
|
||||
batch_size=min(len(image_metadata), batch_size))
|
||||
except RuntimeError as e:
|
||||
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
||||
logger.error(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
||||
continue
|
||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||
if verbose > 0:
|
||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
logger.info(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
|
||||
return image_metadata_embeddings
|
||||
|
||||
|
||||
def extract_metadata(image_name, verbose=0):
|
||||
def extract_metadata(image_name):
|
||||
with exiftool.ExifTool() as et:
|
||||
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
|
||||
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
|
||||
|
@ -126,8 +130,7 @@ def extract_metadata(image_name, verbose=0):
|
|||
if len(image_metadata_subjects) > 0:
|
||||
image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
|
||||
|
||||
if verbose > 2:
|
||||
print(f"{image_name}:\t{image_processed_metadata}")
|
||||
logger.debug(f"{image_name}:\t{image_processed_metadata}")
|
||||
|
||||
return image_processed_metadata
|
||||
|
||||
|
@ -137,26 +140,34 @@ def query(raw_query, count, model: ImageSearchModel):
|
|||
if pathlib.Path(raw_query).is_file():
|
||||
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True)
|
||||
query = copy.deepcopy(Image.open(query_imagepath))
|
||||
if model.verbose > 0:
|
||||
print(f"Find Images similar to Image at {query_imagepath}")
|
||||
query.thumbnail((640, query.height)) # scale down image for faster processing
|
||||
logger.info(f"Find Images similar to Image at {query_imagepath}")
|
||||
else:
|
||||
query = raw_query
|
||||
if model.verbose > 0:
|
||||
print(f"Find Images by Text: {query}")
|
||||
logger.info(f"Find Images by Text: {query}")
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
start = time.time()
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
end = time.time()
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
start = time.time()
|
||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
start = time.time()
|
||||
metadata_hits = {result['corpus_id']: result['score']
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
for corpus_id, score in metadata_hits.items():
|
||||
|
@ -219,7 +230,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
|||
return results
|
||||
|
||||
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
|
||||
def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
|
||||
# Initialize Model
|
||||
encoder = initialize_model(search_config)
|
||||
|
||||
|
@ -227,9 +238,13 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
|||
absolute_image_files, filtered_image_files = set(), set()
|
||||
if config.input_directories:
|
||||
image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
|
||||
absolute_image_files = set(extract_entries(image_directories, verbose))
|
||||
absolute_image_files = set(extract_entries(image_directories))
|
||||
if config.input_filter:
|
||||
filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter)))
|
||||
filtered_image_files = {
|
||||
filtered_file
|
||||
for input_filter in config.input_filter
|
||||
for filtered_file in glob.glob(get_absolute_path(input_filter))
|
||||
}
|
||||
|
||||
all_image_files = sorted(list(absolute_image_files | filtered_image_files))
|
||||
|
||||
|
@ -241,38 +256,9 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
|
|||
embeddings_file,
|
||||
batch_size=config.batch_size,
|
||||
regenerate=regenerate,
|
||||
use_xmp_metadata=config.use_xmp_metadata,
|
||||
verbose=verbose)
|
||||
use_xmp_metadata=config.use_xmp_metadata)
|
||||
|
||||
return ImageSearchModel(all_image_files,
|
||||
image_embeddings,
|
||||
image_metadata_embeddings,
|
||||
encoder,
|
||||
verbose)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup Argument Parser
|
||||
parser = argparse.ArgumentParser(description="Semantic Search on Images")
|
||||
parser.add_argument('--image-directory', '-i', required=True, type=pathlib.Path, help="Image directory to query")
|
||||
parser.add_argument('--embeddings-file', '-e', default='image_embeddings.pt', type=pathlib.Path, help="File to save/load model embeddings to/from. Default: ./embeddings.pt")
|
||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings of Images in Image Directory . Default: false")
|
||||
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
|
||||
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
|
||||
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
image_names, image_embeddings, image_metadata_embeddings, model = setup(args.image_directory, args.embeddings_file, regenerate=args.regenerate)
|
||||
|
||||
# Run User Queries on Entries in Interactive Mode
|
||||
while args.interactive:
|
||||
# get query from user
|
||||
user_query = input("Enter your query: ")
|
||||
if user_query == "exit":
|
||||
exit(0)
|
||||
|
||||
# query images
|
||||
hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose)
|
||||
|
||||
# render results
|
||||
render_results(hits, image_names, args.image_directory, count=args.results_count)
|
||||
encoder)
|
||||
|
|
|
@ -1,21 +1,23 @@
|
|||
# Standard Packages
|
||||
import argparse
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
import time
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
from src.utils.jsonl import load_jsonl
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_model(search_config: TextSearchConfig):
|
||||
"Initialize model for semantic search on text"
|
||||
torch.set_num_threads(4)
|
||||
|
@ -46,73 +48,85 @@ def initialize_model(search_config: TextSearchConfig):
|
|||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
|
||||
def extract_entries(jsonl_file, verbose=0):
|
||||
def extract_entries(jsonl_file):
|
||||
"Load entries from compressed jsonl"
|
||||
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
|
||||
for entry
|
||||
in load_jsonl(jsonl_file, verbose=verbose)]
|
||||
return load_jsonl(jsonl_file)
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, verbose=0):
|
||||
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
if verbose > 0:
|
||||
print(f"Loaded embeddings from {embeddings_file}")
|
||||
logger.info(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
|
||||
if new_entries:
|
||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
|
||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
|
||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||
# Else compute the corpus embeddings from scratch
|
||||
else:
|
||||
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
|
||||
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
|
||||
# Save regenerated or updated embeddings to file
|
||||
if new_entries:
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
print(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = [], verbose=0):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
||||
start = time.time()
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
if filters_in_query:
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
else:
|
||||
corpus_embeddings = model.corpus_embeddings
|
||||
entries = model.entries
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Copy Time: {end - start:.3f} seconds")
|
||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
start_filter = time.time()
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return [], []
|
||||
else:
|
||||
start = time.time()
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Filter Time: {end - start:.3f} seconds")
|
||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||
|
||||
end_filter = time.time()
|
||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
|
||||
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
|
||||
# If query only had filters it'll be empty now. So short-circuit and return results.
|
||||
if query.strip() == "":
|
||||
hits = [{"corpus_id": id, "score": 1.0} for id, _ in enumerate(entries)]
|
||||
return hits, entries
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
|
@ -120,8 +134,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
|||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
|
@ -133,8 +146,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
|
|||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
|
||||
return hits, entries
|
||||
|
||||
|
@ -167,50 +179,26 @@ def collate_results(hits, entries, count=5):
|
|||
in hits[0:count]]
|
||||
|
||||
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
|
||||
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
|
||||
# Initialize Model
|
||||
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
|
||||
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
if not config.compressed_jsonl.exists() or regenerate:
|
||||
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
|
||||
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() and not regenerate else None
|
||||
entries_with_indices = text_to_jsonl(config, previous_entries)
|
||||
|
||||
# Extract Entries
|
||||
entries = extract_entries(config.compressed_jsonl, verbose)
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
if is_none_or_empty(entries):
|
||||
raise ValueError(f"No valid entries found in specified files: {config.input_files} or {config.input_filter}")
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=verbose)
|
||||
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup Argument Parser
|
||||
parser = argparse.ArgumentParser(description="Map Text files into (compressed) JSONL format")
|
||||
parser.add_argument('--input-files', '-i', nargs='*', help="List of Text files to process")
|
||||
parser.add_argument('--input-filter', type=str, default=None, help="Regex filter for Text files to process")
|
||||
parser.add_argument('--compressed-jsonl', '-j', type=pathlib.Path, default=pathlib.Path("text.jsonl.gz"), help="Compressed JSONL to compute embeddings from")
|
||||
parser.add_argument('--embeddings', '-e', type=pathlib.Path, default=pathlib.Path("text_embeddings.pt"), help="File to save/load model embeddings to/from")
|
||||
parser.add_argument('--regenerate', action='store_true', default=False, help="Regenerate embeddings from text files. Default: false")
|
||||
parser.add_argument('--results-count', '-n', default=5, type=int, help="Number of results to render. Default: 5")
|
||||
parser.add_argument('--interactive', action='store_true', default=False, help="Interactive mode allows user to run queries on the model. Default: true")
|
||||
parser.add_argument('--verbose', action='count', default=0, help="Show verbose conversion logs. Default: 0")
|
||||
args = parser.parse_args()
|
||||
|
||||
entries, corpus_embeddings, bi_encoder, cross_encoder, top_k = setup(args.input_files, args.input_filter, args.compressed_jsonl, args.embeddings, args.regenerate, args.verbose)
|
||||
|
||||
# Run User Queries on Entries in Interactive Mode
|
||||
while args.interactive:
|
||||
# get query from user
|
||||
user_query = input("Enter your query: ")
|
||||
if user_query == "exit":
|
||||
exit(0)
|
||||
|
||||
# query notes
|
||||
hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k)
|
||||
|
||||
# render results
|
||||
render_results(hits, entries, count=args.results_count)
|
||||
return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Standard Packages
|
||||
import argparse
|
||||
import pathlib
|
||||
from importlib.metadata import version
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
|
@ -16,9 +17,15 @@ def cli(args=None):
|
|||
parser.add_argument('--host', type=str, default='127.0.0.1', help="Host address of the server. Default: 127.0.0.1")
|
||||
parser.add_argument('--port', '-p', type=int, default=8000, help="Port of the server. Default: 8000")
|
||||
parser.add_argument('--socket', type=pathlib.Path, help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock")
|
||||
parser.add_argument('--version', '-V', action='store_true', help="Print the installed Khoj version and exit")
|
||||
|
||||
args = parser.parse_args(args)
|
||||
|
||||
if args.version:
|
||||
# Show version of khoj installed and exit
|
||||
print(version('khoj-assistant'))
|
||||
exit(0)
|
||||
|
||||
# Normalize config_file path to absolute path
|
||||
args.config_file = resolve_absolute_path(args.config_file)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
|
||||
# Internal Packages
|
||||
from src.utils.rawconfig import ConversationProcessorConfig
|
||||
from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
|
@ -21,23 +22,22 @@ class ProcessorType(str, Enum):
|
|||
|
||||
|
||||
class TextSearchModel():
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose):
|
||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||
self.entries = entries
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
self.bi_encoder = bi_encoder
|
||||
self.cross_encoder = cross_encoder
|
||||
self.filters = filters
|
||||
self.top_k = top_k
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
class ImageSearchModel():
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose):
|
||||
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
|
||||
self.image_encoder = image_encoder
|
||||
self.image_names = image_names
|
||||
self.image_embeddings = image_embeddings
|
||||
self.image_metadata_embeddings = image_metadata_embeddings
|
||||
self.image_encoder = image_encoder
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -51,12 +51,11 @@ class SearchModels():
|
|||
|
||||
|
||||
class ConversationProcessorConfigModel():
|
||||
def __init__(self, processor_config: ConversationProcessorConfig, verbose: bool):
|
||||
def __init__(self, processor_config: ConversationProcessorConfig):
|
||||
self.openai_api_key = processor_config.openai_api_key
|
||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||
self.chat_session = ''
|
||||
self.meta_log = []
|
||||
self.verbose = verbose
|
||||
self.meta_log: dict = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
app_root_directory = Path(__file__).parent.parent.parent
|
||||
web_directory = app_root_directory / 'src/interface/web/'
|
||||
empty_escape_sequences = r'\n|\r\t '
|
||||
empty_escape_sequences = '\n|\r|\t| '
|
||||
|
||||
# default app config to use
|
||||
default_config = {
|
||||
|
@ -11,7 +11,8 @@ default_config = {
|
|||
'input-files': None,
|
||||
'input-filter': None,
|
||||
'compressed-jsonl': '~/.khoj/content/org/org.jsonl.gz',
|
||||
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt'
|
||||
'embeddings-file': '~/.khoj/content/org/org_embeddings.pt',
|
||||
'index_heading_entries': False
|
||||
},
|
||||
'markdown': {
|
||||
'input-files': None,
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
# Standard Packages
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
import hashlib
|
||||
from os.path import join
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
def is_none_or_empty(item):
|
||||
|
@ -12,12 +16,12 @@ def to_snake_case_from_dash(item: str):
|
|||
return item.replace('_', '-')
|
||||
|
||||
|
||||
def get_absolute_path(filepath):
|
||||
return str(pathlib.Path(filepath).expanduser().absolute())
|
||||
def get_absolute_path(filepath: Union[str, Path]) -> str:
|
||||
return str(Path(filepath).expanduser().absolute())
|
||||
|
||||
|
||||
def resolve_absolute_path(filepath, strict=False):
|
||||
return pathlib.Path(filepath).expanduser().absolute().resolve(strict=strict)
|
||||
def resolve_absolute_path(filepath: Union[str, Optional[Path]], strict=False) -> Path:
|
||||
return Path(filepath).expanduser().absolute().resolve(strict=strict)
|
||||
|
||||
|
||||
def get_from_dict(dictionary, *args):
|
||||
|
@ -61,3 +65,56 @@ def load_model(model_name, model_dir, model_type, device:str=None):
|
|||
def is_pyinstaller_app():
|
||||
"Returns true if the app is running from Native GUI created by PyInstaller"
|
||||
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
|
||||
|
||||
|
||||
class LRU(OrderedDict):
|
||||
def __init__(self, *args, capacity=128, **kwargs):
|
||||
self.capacity = capacity
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = super().__getitem__(key)
|
||||
self.move_to_end(key)
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super().__setitem__(key, value)
|
||||
if len(self) > self.capacity:
|
||||
oldest = next(iter(self))
|
||||
del self[oldest]
|
||||
|
||||
|
||||
def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
|
||||
# Hash all current and previous entries to identify new entries
|
||||
start = time.time()
|
||||
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
|
||||
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
|
||||
end = time.time()
|
||||
logger.debug(f"Hash previous, current entries: {end - start} seconds")
|
||||
|
||||
start = time.time()
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
|
||||
|
||||
# All entries that did not exist in the previous set are to be added
|
||||
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
|
||||
# All entries that exist in both current and previous sets are kept
|
||||
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
|
||||
|
||||
# Mark new entries with no ids for later embeddings generation
|
||||
new_entries = [
|
||||
(None, hash_to_current_entries[entry_hash])
|
||||
for entry_hash in new_entry_hashes
|
||||
]
|
||||
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
|
||||
existing_entries = [
|
||||
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
|
||||
for entry_hash in existing_entry_hashes
|
||||
]
|
||||
|
||||
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
|
||||
entries_with_ids = existing_entries_sorted + new_entries
|
||||
end = time.time()
|
||||
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
|
||||
|
||||
return entries_with_ids
|
|
@ -1,13 +1,17 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
import gzip
|
||||
import logging
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.constants import empty_escape_sequences
|
||||
from src.utils.helpers import get_absolute_path
|
||||
|
||||
|
||||
def load_jsonl(input_path, verbose=0):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_jsonl(input_path):
|
||||
"Read List of JSON objects from JSON line file"
|
||||
# Initialize Variables
|
||||
data = []
|
||||
|
@ -27,13 +31,12 @@ def load_jsonl(input_path, verbose=0):
|
|||
jsonl_file.close()
|
||||
|
||||
# Log JSONL entries loaded
|
||||
if verbose > 0:
|
||||
print(f'Loaded {len(data)} records from {input_path}')
|
||||
logger.info(f'Loaded {len(data)} records from {input_path}')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def dump_jsonl(jsonl_data, output_path, verbose=0):
|
||||
def dump_jsonl(jsonl_data, output_path):
|
||||
"Write List of JSON objects to JSON line file"
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -41,16 +44,14 @@ def dump_jsonl(jsonl_data, output_path, verbose=0):
|
|||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(jsonl_data)
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Wrote {len(jsonl_data)} lines to jsonl at {output_path}')
|
||||
logger.info(f'Wrote jsonl data to {output_path}')
|
||||
|
||||
|
||||
def compress_jsonl_data(jsonl_data, output_path, verbose=0):
|
||||
def compress_jsonl_data(jsonl_data, output_path):
|
||||
# Create output directory, if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with gzip.open(output_path, 'wt') as gzip_file:
|
||||
gzip_file.write(jsonl_data)
|
||||
|
||||
if verbose > 0:
|
||||
print(f'Wrote {len(jsonl_data)} lines to gzip compressed jsonl at {output_path}')
|
||||
logger.info(f'Wrote jsonl data to gzip compressed jsonl at {output_path}')
|
|
@ -6,7 +6,7 @@ from typing import List, Optional
|
|||
from pydantic import BaseModel, validator
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import to_snake_case_from_dash
|
||||
from src.utils.helpers import to_snake_case_from_dash, is_none_or_empty
|
||||
|
||||
class ConfigBase(BaseModel):
|
||||
class Config:
|
||||
|
@ -15,26 +15,27 @@ class ConfigBase(BaseModel):
|
|||
|
||||
class TextContentConfig(ConfigBase):
|
||||
input_files: Optional[List[Path]]
|
||||
input_filter: Optional[str]
|
||||
input_filter: Optional[List[str]]
|
||||
compressed_jsonl: Path
|
||||
embeddings_file: Path
|
||||
index_heading_entries: Optional[bool] = False
|
||||
|
||||
@validator('input_filter')
|
||||
def input_filter_or_files_required(cls, input_filter, values, **kwargs):
|
||||
if input_filter is None and ('input_files' not in values or values["input_files"] is None):
|
||||
if is_none_or_empty(input_filter) and ('input_files' not in values or values["input_files"] is None):
|
||||
raise ValueError("Either input_filter or input_files required in all content-type.<text_search> section of Khoj config file")
|
||||
return input_filter
|
||||
|
||||
class ImageContentConfig(ConfigBase):
|
||||
input_directories: Optional[List[Path]]
|
||||
input_filter: Optional[str]
|
||||
input_filter: Optional[List[str]]
|
||||
embeddings_file: Path
|
||||
use_xmp_metadata: bool
|
||||
batch_size: int
|
||||
|
||||
@validator('input_filter')
|
||||
def input_filter_or_directories_required(cls, input_filter, values, **kwargs):
|
||||
if input_filter is None and ('input_directories' not in values or values["input_directories"] is None):
|
||||
if is_none_or_empty(input_filter) and ('input_directories' not in values or values["input_directories"] is None):
|
||||
raise ValueError("Either input_filter or input_directories required in all content-type.image section of Khoj config file")
|
||||
return input_filter
|
||||
|
||||
|
|
|
@ -1,22 +1,25 @@
|
|||
# Standard Packages
|
||||
from packaging import version
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.config import SearchModels, ProcessorConfigModel
|
||||
from src.utils.helpers import LRU
|
||||
from src.utils.rawconfig import FullConfig
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
model = SearchModels()
|
||||
processor_config = ProcessorConfigModel()
|
||||
config_file: Path = ""
|
||||
config_file: Path = None
|
||||
verbose: int = 0
|
||||
host: str = None
|
||||
port: int = None
|
||||
cli_args = None
|
||||
cli_args: list[str] = None
|
||||
query_cache = LRU()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use CUDA GPU
|
||||
|
|
|
@ -5,12 +5,13 @@ from pathlib import Path
|
|||
import yaml
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path
|
||||
from src.utils.rawconfig import FullConfig
|
||||
|
||||
|
||||
# Do not emit tags when dumping to YAML
|
||||
yaml.emitter.Emitter.process_tag = lambda self, *args, **kwargs: None
|
||||
|
||||
|
||||
def save_config_to_file(yaml_config: dict, yaml_config_file: Path):
|
||||
"Write config to YML file"
|
||||
# Create output directory, if it doesn't exist
|
||||
|
|
|
@ -1,78 +1,65 @@
|
|||
# Standard Packages
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from src.search_type import image_search, text_search
|
||||
from src.utils.config import SearchType
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
|
||||
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from src.utils import state
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
from src.search_filter.word_filter import WordFilter
|
||||
from src.search_filter.file_filter import FileFilter
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def search_config(tmp_path_factory):
|
||||
model_dir = tmp_path_factory.mktemp('data')
|
||||
|
||||
def search_config() -> SearchConfig:
|
||||
model_dir = resolve_absolute_path('~/.khoj/search')
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
search_config = SearchConfig()
|
||||
|
||||
search_config.symmetric = TextSearchConfig(
|
||||
encoder = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory = model_dir
|
||||
model_directory = model_dir / 'symmetric/'
|
||||
)
|
||||
|
||||
search_config.asymmetric = TextSearchConfig(
|
||||
encoder = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory = model_dir
|
||||
model_directory = model_dir / 'asymmetric/'
|
||||
)
|
||||
|
||||
search_config.image = ImageSearchConfig(
|
||||
encoder = "sentence-transformers/clip-ViT-B-32",
|
||||
model_directory = model_dir
|
||||
model_directory = model_dir / 'image/'
|
||||
)
|
||||
|
||||
return search_config
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def model_dir(search_config):
|
||||
model_dir = search_config.asymmetric.model_directory
|
||||
def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||
content_dir = tmp_path_factory.mktemp('content')
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
content_config = ContentConfig()
|
||||
content_config.image = ImageContentConfig(
|
||||
input_directories = ['tests/data/images'],
|
||||
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
batch_size = 10,
|
||||
embeddings_file = content_dir.joinpath('image_embeddings.pt'),
|
||||
batch_size = 1,
|
||||
use_xmp_metadata = False)
|
||||
|
||||
image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
|
||||
image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
input_files = None,
|
||||
input_filter = 'tests/data/org/*.org',
|
||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||
input_filter = ['tests/data/org/*.org'],
|
||||
compressed_jsonl = content_dir.joinpath('notes.jsonl.gz'),
|
||||
embeddings_file = content_dir.joinpath('note_embeddings.pt'))
|
||||
|
||||
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
|
||||
|
||||
return model_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def content_config(model_dir):
|
||||
content_config = ContentConfig()
|
||||
content_config.org = TextContentConfig(
|
||||
input_files = None,
|
||||
input_filter = 'tests/data/org/*.org',
|
||||
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
|
||||
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
|
||||
|
||||
content_config.image = ImageContentConfig(
|
||||
input_directories = ['tests/data/images'],
|
||||
embeddings_file = model_dir.joinpath('image_embeddings.pt'),
|
||||
batch_size = 1,
|
||||
use_xmp_metadata = False)
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||
|
||||
return content_config
|
|
@ -1,9 +1,10 @@
|
|||
content-type:
|
||||
org:
|
||||
input-files: [ "~/first_from_config.org", "~/second_from_config.org" ]
|
||||
input-filter: "*.org"
|
||||
input-filter: ["*.org", "~/notes/*.org"]
|
||||
compressed-jsonl: ".notes.json.gz"
|
||||
embeddings-file: ".note_embeddings.pt"
|
||||
index-header-entries: true
|
||||
|
||||
search-type:
|
||||
asymmetric:
|
||||
|
|
112
tests/test_beancount_to_jsonl.py
Normal file
112
tests/test_beancount_to_jsonl.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.ledger.beancount_to_jsonl import extract_beancount_transactions, convert_transactions_to_maps, convert_transaction_maps_to_jsonl, get_beancount_files
|
||||
|
||||
|
||||
def test_no_transactions_in_file(tmp_path):
|
||||
"Handle file with no transactions."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
'''
|
||||
beancount_file = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Beancount files
|
||||
entry_nodes, file_to_entries = extract_beancount_transactions(beancount_files=[beancount_file])
|
||||
|
||||
# Process Each Entry from All Beancount Files
|
||||
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entry_nodes, file_to_entries))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 0
|
||||
|
||||
|
||||
def test_single_beancount_transaction_to_jsonl(tmp_path):
|
||||
"Convert transaction from single file to jsonl."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
1984-04-01 * "Payee" "Narration"
|
||||
Expenses:Test:Test 1.00 KES
|
||||
Assets:Test:Test -1.00 KES
|
||||
'''
|
||||
beancount_file = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Beancount files
|
||||
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
|
||||
|
||||
# Process Each Entry from All Beancount Files
|
||||
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
|
||||
|
||||
def test_multiple_transactions_to_jsonl(tmp_path):
|
||||
"Convert multiple transactions from single file to jsonl."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
1984-04-01 * "Payee" "Narration"
|
||||
Expenses:Test:Test 1.00 KES
|
||||
Assets:Test:Test -1.00 KES
|
||||
\t\r
|
||||
1984-04-01 * "Payee" "Narration"
|
||||
Expenses:Test:Test 1.00 KES
|
||||
Assets:Test:Test -1.00 KES
|
||||
'''
|
||||
|
||||
beancount_file = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Beancount files
|
||||
entries, entry_to_file_map = extract_beancount_transactions(beancount_files=[beancount_file])
|
||||
|
||||
# Process Each Entry from All Beancount Files
|
||||
jsonl_string = convert_transaction_maps_to_jsonl(convert_transactions_to_maps(entries, entry_to_file_map))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
|
||||
|
||||
def test_get_beancount_files(tmp_path):
|
||||
"Ensure Beancount files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
# Include via input-filter globs
|
||||
group1_file1 = create_file(tmp_path, filename="group1-file1.bean")
|
||||
group1_file2 = create_file(tmp_path, filename="group1-file2.bean")
|
||||
group2_file1 = create_file(tmp_path, filename="group2-file1.beancount")
|
||||
group2_file2 = create_file(tmp_path, filename="group2-file2.beancount")
|
||||
# Include via input-file field
|
||||
file1 = create_file(tmp_path, filename="ledger.bean")
|
||||
# Not included by any filter
|
||||
create_file(tmp_path, filename="not-included-ledger.bean")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / 'ledger.bean']
|
||||
input_filter = [tmp_path / 'group1*.bean', tmp_path / 'group2*.beancount']
|
||||
|
||||
# Act
|
||||
extracted_org_files = get_beancount_files(input_files, input_filter)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry=None, filename="ledger.beancount"):
|
||||
beancount_file = tmp_path / filename
|
||||
beancount_file.touch()
|
||||
if entry:
|
||||
beancount_file.write_text(entry)
|
||||
return beancount_file
|
|
@ -2,9 +2,6 @@
|
|||
from pathlib import Path
|
||||
from random import random
|
||||
|
||||
# External Modules
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.cli import cli
|
||||
from src.utils.helpers import resolve_absolute_path
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
# Standard Modules
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from urllib.parse import quote
|
||||
|
||||
|
||||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from src.main import app
|
||||
from src.utils.state import model, config
|
||||
from src.search_type import text_search, image_search
|
||||
from src.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from src.processor.org_mode import org_to_jsonl
|
||||
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
||||
from src.search_filter.word_filter import WordFilter
|
||||
from src.search_filter.file_filter import FileFilter
|
||||
|
||||
|
||||
# Arrange
|
||||
|
@ -22,7 +25,7 @@ client = TestClient(app)
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_invalid_content_type():
|
||||
# Arrange
|
||||
user_query = "How to call Khoj from Emacs?"
|
||||
user_query = quote("How to call Khoj from Emacs?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&t=invalid_content_type")
|
||||
|
@ -116,7 +119,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
user_query = "How to git install application?"
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
|
||||
|
@ -129,17 +132,35 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||
def test_notes_search_with_only_filters(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
user_query = "How to git install application? +Emacs"
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual_data contains explicitly included word "Emacs"
|
||||
# assert actual_data contains word "Emacs"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "Emacs" in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual_data contains word "Emacs"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "Emacs" in search_result
|
||||
|
||||
|
@ -147,14 +168,15 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
user_query = "How to git install application? -clone"
|
||||
filters = [WordFilter()]
|
||||
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual_data does not contains explicitly excluded word "Emacs"
|
||||
# assert actual_data does not contains word "Emacs"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "clone" not in search_result
|
||||
|
|
|
@ -18,40 +18,34 @@ def test_date_filter():
|
|||
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
|
||||
|
||||
q_with_no_date_filter = 'head tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert len(ret_emb) == 3
|
||||
assert ret_entries == entries
|
||||
assert entry_indices == {0, 1, 2}
|
||||
|
||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert len(ret_emb) == 0
|
||||
assert ret_entries == []
|
||||
assert entry_indices == set()
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[2]]
|
||||
assert len(ret_emb) == 1
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[1]]
|
||||
assert len(ret_emb) == 1
|
||||
assert entry_indices == {1}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[2]]
|
||||
assert len(ret_emb) == 1
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[1], entries[2]]
|
||||
assert len(ret_emb) == 2
|
||||
assert entry_indices == {1, 2}
|
||||
|
||||
|
||||
def test_extract_date_range():
|
||||
|
|
112
tests/test_file_filter.py
Normal file
112
tests/test_file_filter.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
# External Packages
|
||||
import torch
|
||||
|
||||
# Application Packages
|
||||
from src.search_filter.file_filter import FileFilter
|
||||
|
||||
|
||||
def test_no_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_file_filter_with_non_existent_file():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {}
|
||||
|
||||
|
||||
def test_single_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"file 1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_partial_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_regex_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"*.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_multiple_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
embeddings, entries = arrange_content()
|
||||
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def arrange_content():
|
||||
embeddings = torch.randn(4, 10)
|
||||
entries = [
|
||||
{'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
|
||||
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
|
||||
{'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
|
||||
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
|
||||
|
||||
return embeddings, entries
|
|
@ -28,3 +28,18 @@ def test_merge_dicts():
|
|||
|
||||
# do not override existing key in priority_dict with default dict
|
||||
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
|
||||
|
||||
|
||||
def test_lru_cache():
|
||||
# Test initializing cache
|
||||
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
|
||||
assert cache == {'a': 1, 'b': 2}
|
||||
|
||||
# Test capacity overflow
|
||||
cache['c'] = 3
|
||||
assert cache == {'b': 2, 'c': 3}
|
||||
|
||||
# Test delete least recently used item from LRU cache on capacity overflow
|
||||
cache['b'] # accessing 'b' makes it the most recently used item
|
||||
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b'
|
||||
assert cache == {'b': 2, 'd': 4}
|
||||
|
|
|
@ -48,8 +48,13 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
image_files_url='/static/images',
|
||||
count=1)
|
||||
|
||||
actual_image = Image.open(output_directory.joinpath(Path(results[0]["entry"]).name))
|
||||
actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name)
|
||||
actual_image = Image.open(actual_image_path)
|
||||
expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
|
||||
|
||||
# Assert
|
||||
assert expected_image == actual_image
|
||||
|
||||
# Cleanup
|
||||
# Delete the image files copied to results directory
|
||||
actual_image_path.unlink()
|
||||
|
|
109
tests/test_markdown_to_jsonl.py
Normal file
109
tests/test_markdown_to_jsonl.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.markdown.markdown_to_jsonl import extract_markdown_entries, convert_markdown_maps_to_jsonl, convert_markdown_entries_to_maps, get_markdown_files
|
||||
|
||||
|
||||
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
"Convert files with no heading to jsonl."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
'''
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_nodes, file_to_entries = extract_markdown_entries(markdown_files=[markdownfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entry_nodes, file_to_entries))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
|
||||
|
||||
def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||
"Convert markdown entry from single file to jsonl."
|
||||
# Arrange
|
||||
entry = f'''### Heading
|
||||
\t\r
|
||||
Body Line 1
|
||||
'''
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
|
||||
|
||||
def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||
"Convert multiple markdown entries from single file to jsonl."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
### Heading 1
|
||||
\t\r
|
||||
Heading 1 Body Line 1
|
||||
### Heading 2
|
||||
\t\r
|
||||
Heading 2 Body Line 2
|
||||
'''
|
||||
markdownfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = extract_markdown_entries(markdown_files=[markdownfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = convert_markdown_maps_to_jsonl(convert_markdown_entries_to_maps(entries, entry_to_file_map))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
|
||||
|
||||
def test_get_markdown_files(tmp_path):
|
||||
"Ensure Markdown files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
# Include via input-filter globs
|
||||
group1_file1 = create_file(tmp_path, filename="group1-file1.md")
|
||||
group1_file2 = create_file(tmp_path, filename="group1-file2.md")
|
||||
group2_file1 = create_file(tmp_path, filename="group2-file1.markdown")
|
||||
group2_file2 = create_file(tmp_path, filename="group2-file2.markdown")
|
||||
# Include via input-file field
|
||||
file1 = create_file(tmp_path, filename="notes.md")
|
||||
# Not included by any filter
|
||||
create_file(tmp_path, filename="not-included-markdown.md")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / 'notes.md']
|
||||
input_filter = [tmp_path / 'group1*.md', tmp_path / 'group2*.markdown']
|
||||
|
||||
# Act
|
||||
extracted_org_files = get_markdown_files(input_files, input_filter)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry=None, filename="test.md"):
|
||||
markdown_file = tmp_path / filename
|
||||
markdown_file.touch()
|
||||
if entry:
|
||||
markdown_file.write_text(entry)
|
||||
return markdown_file
|
|
@ -1,32 +1,37 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
from posixpath import split
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries
|
||||
from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries, get_org_files
|
||||
from src.utils.helpers import is_none_or_empty
|
||||
|
||||
|
||||
def test_entry_with_empty_body_line_to_jsonl(tmp_path):
|
||||
'''Ensure entries with empty body are ignored.
|
||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
'''Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty.'''
|
||||
# Arrange
|
||||
entry = f'''*** Heading
|
||||
:PROPERTIES:
|
||||
:ID: 42-42-42
|
||||
:END:
|
||||
\t\r\n
|
||||
\t \r
|
||||
'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
for index_heading_entries in [True, False]:
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries = extract_org_entries(org_files=[orgfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_data = convert_org_entries_to_jsonl(entries)
|
||||
# Extract entries into jsonl from specified Org files
|
||||
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(
|
||||
*extract_org_entries(org_files=[orgfile]),
|
||||
index_heading_entries=index_heading_entries))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
if index_heading_entries:
|
||||
# Entry with empty body indexed when index_heading_entries set to True
|
||||
assert len(jsonl_data) == 1
|
||||
else:
|
||||
# Entry with empty body ignored when index_heading_entries set to False
|
||||
assert is_none_or_empty(jsonl_data)
|
||||
|
||||
|
||||
|
@ -37,15 +42,38 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
|||
:PROPERTIES:
|
||||
:ID: 42-42-42
|
||||
:END:
|
||||
\t\r\nBody Line 1\n
|
||||
\t\r
|
||||
Body Line 1
|
||||
'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries = extract_org_entries(org_files=[orgfile])
|
||||
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = convert_org_entries_to_jsonl(convert_org_nodes_to_entries(entries, entry_to_file_map))
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
|
||||
|
||||
def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||
"Ensure files with no heading, only body text are loaded."
|
||||
# Arrange
|
||||
entry = f'''
|
||||
- Bullet point 1
|
||||
- Bullet point 2
|
||||
'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
|
@ -53,10 +81,38 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
|||
assert len(jsonl_data) == 1
|
||||
|
||||
|
||||
def test_get_org_files(tmp_path):
|
||||
"Ensure Org files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
# Include via input-filter globs
|
||||
group1_file1 = create_file(tmp_path, filename="group1-file1.org")
|
||||
group1_file2 = create_file(tmp_path, filename="group1-file2.org")
|
||||
group2_file1 = create_file(tmp_path, filename="group2-file1.org")
|
||||
group2_file2 = create_file(tmp_path, filename="group2-file2.org")
|
||||
# Include via input-file field
|
||||
orgfile1 = create_file(tmp_path, filename="orgfile1.org")
|
||||
# Not included by any filter
|
||||
create_file(tmp_path, filename="orgfile2.org")
|
||||
create_file(tmp_path, filename="text1.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, orgfile1]))
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / 'orgfile1.org']
|
||||
input_filter = [tmp_path / 'group1*.org', tmp_path / 'group2*.org']
|
||||
|
||||
# Act
|
||||
extracted_org_files = get_org_files(input_files, input_filter)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry, filename="test.org"):
|
||||
org_file = tmp_path / f"notes/{filename}"
|
||||
org_file.parent.mkdir()
|
||||
def create_file(tmp_path, entry=None, filename="test.org"):
|
||||
org_file = tmp_path / filename
|
||||
org_file.touch()
|
||||
if entry:
|
||||
org_file.write_text(entry)
|
||||
return org_file
|
|
@ -1,13 +1,33 @@
|
|||
# Standard Packages
|
||||
import datetime
|
||||
from os.path import relpath
|
||||
from pathlib import Path
|
||||
|
||||
# Internal Packages
|
||||
from src.processor.org_mode import orgnode
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_parse_entry_with_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f'''Body Line 1'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].heading == f'{orgfile}'
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].closed == ""
|
||||
assert entries[0].scheduled == ""
|
||||
assert entries[0].deadline == ""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_parse_minimal_entry(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
|
@ -22,14 +42,14 @@ Body Line 1'''
|
|||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].Heading() == "Heading"
|
||||
assert entries[0].Tags() == set()
|
||||
assert entries[0].Body() == "Body Line 1"
|
||||
assert entries[0].Priority() == ""
|
||||
assert entries[0].heading == "Heading"
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].Closed() == ""
|
||||
assert entries[0].Scheduled() == ""
|
||||
assert entries[0].Deadline() == ""
|
||||
assert entries[0].closed == ""
|
||||
assert entries[0].scheduled == ""
|
||||
assert entries[0].deadline == ""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -55,16 +75,44 @@ Body Line 2'''
|
|||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].Heading() == "Heading"
|
||||
assert entries[0].Todo() == "DONE"
|
||||
assert entries[0].Tags() == {"Tag1", "TAG2", "tag3"}
|
||||
assert entries[0].Body() == "- Clocked Log 1\nBody Line 1\nBody Line 2"
|
||||
assert entries[0].Priority() == "A"
|
||||
assert entries[0].heading == "Heading"
|
||||
assert entries[0].todo == "DONE"
|
||||
assert entries[0].tags == ["Tag1", "TAG2", "tag3"]
|
||||
assert entries[0].body == "- Clocked Log 1\nBody Line 1\nBody Line 2"
|
||||
assert entries[0].priority == "A"
|
||||
assert entries[0].Property("ID") == "id:123-456-789-4234-1231"
|
||||
assert entries[0].Closed() == datetime.date(1984,4,1)
|
||||
assert entries[0].Scheduled() == datetime.date(1984,4,1)
|
||||
assert entries[0].Deadline() == datetime.date(1984,4,1)
|
||||
assert entries[0].Logbook() == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))]
|
||||
assert entries[0].closed == datetime.date(1984,4,1)
|
||||
assert entries[0].scheduled == datetime.date(1984,4,1)
|
||||
assert entries[0].deadline == datetime.date(1984,4,1)
|
||||
assert entries[0].logbook == [(datetime.datetime(1984,4,1,9,0,0), datetime.datetime(1984,4,1,12,0,0))]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
||||
"Render heading entry with property drawer"
|
||||
# Arrange
|
||||
entry_to_render = f'''
|
||||
*** [#A] Heading1 :tag1:
|
||||
:PROPERTIES:
|
||||
:ID: 111-111-111-1111-1111
|
||||
:END:
|
||||
\t\r \n
|
||||
'''
|
||||
orgfile = create_file(tmp_path, entry_to_render)
|
||||
|
||||
expected_entry = f'''*** [#A] Heading1 :tag1:
|
||||
:PROPERTIES:
|
||||
:LINE: file:{orgfile}::2
|
||||
:ID: id:111-111-111-1111-1111
|
||||
:SOURCE: [[file:{orgfile}::*Heading1]]
|
||||
:END:
|
||||
'''
|
||||
|
||||
# Act
|
||||
parsed_entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
assert f'{parsed_entries[0]}' == expected_entry
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -81,18 +129,17 @@ Body Line 1
|
|||
Body Line 2
|
||||
'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}'
|
||||
|
||||
# Act
|
||||
entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
# SOURCE link rendered with Heading
|
||||
assert f':SOURCE: [[file:{normalized_orgfile}::*{entries[0].Heading()}]]' in f'{entries[0]}'
|
||||
assert f':SOURCE: [[file:{orgfile}::*{entries[0].heading}]]' in f'{entries[0]}'
|
||||
# ID link rendered with ID
|
||||
assert f':ID: id:123-456-789-4234-1231' in f'{entries[0]}'
|
||||
# LINE link rendered with line number
|
||||
assert f':LINE: file:{normalized_orgfile}::2' in f'{entries[0]}'
|
||||
assert f':LINE: file:{orgfile}::2' in f'{entries[0]}'
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -113,10 +160,9 @@ Body Line 1'''
|
|||
# Assert
|
||||
assert len(entries) == 1
|
||||
# parsed heading from entry
|
||||
assert entries[0].Heading() == "Heading[1]"
|
||||
assert entries[0].heading == "Heading[1]"
|
||||
# ensure SOURCE link has square brackets in filename, heading escaped in rendered entries
|
||||
normalized_orgfile = f'~/{relpath(orgfile, start=Path.home())}'
|
||||
escaped_orgfile = f'{normalized_orgfile}'.replace("[1]", "\\[1\\]")
|
||||
escaped_orgfile = f'{orgfile}'.replace("[1]", "\\[1\\]")
|
||||
assert f':SOURCE: [[file:{escaped_orgfile}::*Heading\[1\]' in f'{entries[0]}'
|
||||
|
||||
|
||||
|
@ -156,16 +202,86 @@ Body 2
|
|||
# Assert
|
||||
assert len(entries) == 2
|
||||
for index, entry in enumerate(entries):
|
||||
assert entry.Heading() == f"Heading{index+1}"
|
||||
assert entry.Todo() == "FAILED" if index == 0 else "CANCELLED"
|
||||
assert entry.Tags() == {f"tag{index+1}"}
|
||||
assert entry.Body() == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
|
||||
assert entry.Priority() == "A"
|
||||
assert entry.heading == f"Heading{index+1}"
|
||||
assert entry.todo == "FAILED" if index == 0 else "CANCELLED"
|
||||
assert entry.tags == [f"tag{index+1}"]
|
||||
assert entry.body == f"- Clocked Log {index+1}\nBody {index+1}\n\n"
|
||||
assert entry.priority == "A"
|
||||
assert entry.Property("ID") == f"id:123-456-789-4234-000{index+1}"
|
||||
assert entry.Closed() == datetime.date(1984,4,index+1)
|
||||
assert entry.Scheduled() == datetime.date(1984,4,index+1)
|
||||
assert entry.Deadline() == datetime.date(1984,4,index+1)
|
||||
assert entry.Logbook() == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
|
||||
assert entry.closed == datetime.date(1984,4,index+1)
|
||||
assert entry.scheduled == datetime.date(1984,4,index+1)
|
||||
assert entry.deadline == datetime.date(1984,4,index+1)
|
||||
assert entry.logbook == [(datetime.datetime(1984,4,index+1,9,0,0), datetime.datetime(1984,4,index+1,12,0,0))]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_parse_entry_with_empty_title(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f'''#+TITLE:
|
||||
Body Line 1'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].heading == f'{orgfile}'
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].closed == ""
|
||||
assert entries[0].scheduled == ""
|
||||
assert entries[0].deadline == ""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_parse_entry_with_title_and_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f'''#+TITLE: test
|
||||
Body Line 1'''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].heading == 'test'
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].closed == ""
|
||||
assert entries[0].scheduled == ""
|
||||
assert entries[0].deadline == ""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_parse_entry_with_multiple_titles_and_no_headings(tmp_path):
|
||||
"Test parsing of entry with minimal fields"
|
||||
# Arrange
|
||||
entry = f'''#+TITLE: title1
|
||||
Body Line 1
|
||||
#+TITLE: title2 '''
|
||||
orgfile = create_file(tmp_path, entry)
|
||||
|
||||
# Act
|
||||
entries = orgnode.makelist(orgfile)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].heading == 'title1 title2'
|
||||
assert entries[0].tags == list()
|
||||
assert entries[0].body == "Body Line 1\n"
|
||||
assert entries[0].priority == ""
|
||||
assert entries[0].Property("ID") == ""
|
||||
assert entries[0].closed == ""
|
||||
assert entries[0].scheduled == ""
|
||||
assert entries[0].deadline == ""
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# System Packages
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.state import model
|
||||
from src.search_type import text_search
|
||||
|
@ -9,6 +13,39 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
|
|||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup_with_missing_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org"
|
||||
new_org_content_config = deepcopy(content_config.org)
|
||||
new_org_content_config.input_files = [f'{file_to_index}']
|
||||
new_org_content_config.input_filter = None
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(FileNotFoundError):
|
||||
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup_with_empty_file_raises_error(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
file_to_index = Path(content_config.org.input_filter[0]).parent / "new_file_to_index.org"
|
||||
file_to_index.touch()
|
||||
new_org_content_config = deepcopy(content_config.org)
|
||||
new_org_content_config.input_files = [f'{file_to_index}']
|
||||
new_org_content_config.input_filter = None
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r'^No valid entries found*'):
|
||||
text_search.setup(org_to_jsonl, new_org_content_config, search_config.asymmetric, regenerate=True)
|
||||
|
||||
# Cleanup
|
||||
# delete created test file
|
||||
file_to_index.unlink()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Act
|
||||
|
@ -23,7 +60,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
|
@ -51,7 +88,7 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
|||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
|
||||
file_to_add_on_reload = Path(content_config.org.input_filter).parent / "reload.org"
|
||||
file_to_add_on_reload = Path(content_config.org.input_filter[0]).parent / "reload.org"
|
||||
content_config.org.input_files = [f'{file_to_add_on_reload}']
|
||||
|
||||
# append Org-Mode Entry to first Org Input File in Config
|
||||
|
@ -77,3 +114,32 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
|
|||
# delete reload test file added
|
||||
content_config.org.input_files = []
|
||||
file_to_add_on_reload.unlink()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
|
||||
file_to_add_on_update = Path(content_config.org.input_filter[0]).parent / "update.org"
|
||||
content_config.org.input_files = [f'{file_to_add_on_update}']
|
||||
|
||||
# append Org-Mode Entry to first Org Input File in Config
|
||||
with open(file_to_add_on_update, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
|
||||
|
||||
# verify new entry added in updated embeddings, entries
|
||||
assert len(initial_notes_model.entries) == 11
|
||||
assert len(initial_notes_model.corpus_embeddings) == 11
|
||||
|
||||
# Cleanup
|
||||
# delete file added for update testing
|
||||
content_config.org.input_files = []
|
||||
file_to_add_on_update.unlink()
|
||||
|
|
77
tests/test_word_filter.py
Normal file
77
tests/test_word_filter.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Application Packages
|
||||
from src.search_filter.word_filter import WordFilter
|
||||
from src.utils.config import SearchType
|
||||
|
||||
|
||||
def test_no_word_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_word_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_word_include_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_filter = 'head +"include_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {2, 3}
|
||||
|
||||
|
||||
def test_word_include_and_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == 'head tail'
|
||||
assert entry_indices == {2}
|
||||
|
||||
|
||||
def arrange_content():
|
||||
entries = [
|
||||
{'compiled': '', 'raw': 'Minimal Entry'},
|
||||
{'compiled': '', 'raw': 'Entry with exclude_word'},
|
||||
{'compiled': '', 'raw': 'Entry with include_word'},
|
||||
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
|
||||
|
||||
return entries
|
Loading…
Reference in a new issue