#!/usr/bin/env python3
"""
This script can handle three main scenarios, intelligently determining what to do based on the input:
1. A single PDF file → Flatten and re-OCR into a single searchable PDF
2. A folder of numbered image files (e.g., JPG/JPEG) → Combine and OCR into one searchable PDF
3. A folder of PDF files → Batch process each PDF file, flattening and re-OCRing each into its own searchable PDF

Usage:
    ./multi_mode_ocr.py <input_path> [--output <file_or_folder>] [--replace]
                        [--lang <language>] [--threads <num>] [--quiet]

Arguments:
    <input_path>     Path to either: (a) a single PDF file, (b) a folder of images,
                     or (c) a folder of PDF files. (Required)
    --output, -o     Desired output path. Interpreted differently depending on input:
                     - Single PDF or folder of images: Output is one PDF file.
                     - Folder of PDFs: Output is a folder containing new OCRed PDFs.
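                     By default, if you pass neither --output nor --replace:
                       - For a single PDF: the original file is overwritten in place
                         (see --replace below).
                       - For a folder of images: writes "<folder>_searchable.pdf"
                         inside the input folder.
                       - For a folder of multiple PDFs: each PDF is written to an
                         "OCRed_PDFs" subfolder of the input folder, with
                         "_searchable" appended to its name.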
    --replace, -r    Overwrite the original PDF(s) instead of creating a new file
                     (this is only valid if the input is a PDF file or a folder of PDFs).
                     In Single PDF mode, replacement is **default** unless --output is provided.
    --lang, -l       OCR language (default: "eng").
    --threads, -t    Number of threads to use for OCR (default: auto-detect CPU cores).
    --quiet, -q      Suppress output messages (only errors are printed).

Dependencies:
    - Python 3
    - PIL (Pillow)
    - pytesseract (Tesseract OCR)
    - PyPDF2
    - pdf2image (for flattening PDFs)
    - concurrent.futures (built-in)
"""

import os
import sys
import time
import io
import argparse
import multiprocessing
from PIL import Image
import pytesseract
from PyPDF2 import PdfMerger
import concurrent.futures

try:
    from pdf2image import convert_from_path
except ImportError:
    convert_from_path = None


def process_image_ocr(args):
    """
    OCR one page image and hand back the resulting single-page PDF as bytes.

    `args` is a 5-tuple of (image, language, page number, page count,
    verbose flag); packing everything into one argument lets the function be
    passed straight to an executor's map(). Any failure while reporting
    progress or running OCR is printed and signalled by returning None.
    """
    image, ocr_lang, index, page_count, show_progress = args
    try:
        if show_progress:
            # The carriage return keeps the progress counter on a single line.
            sys.stdout.write(f"\rProcessing page {index}/{page_count}...")
            sys.stdout.flush()
        return pytesseract.image_to_pdf_or_hocr(
            image, extension='pdf', lang=ocr_lang
        )
    except Exception as exc:
        print(f"\nError processing page {index}: {exc}")
        return None


def create_searchable_pdf_from_images(image_list, output_pdf, language='eng',
                                      threads=None, verbose=True):
    """
    Create a searchable PDF from a list of PIL Images using OCR and write to output_pdf.

    Args:
        image_list: Page images, in the order they should appear in the PDF.
        output_pdf: Destination path of the merged, searchable PDF.
        language: OCR language code forwarded to the per-page OCR step.
        threads: Number of worker threads. None or a value below 1 falls back
            to the CPU count, because ThreadPoolExecutor rejects max_workers < 1.
        verbose: When True, print progress and timing messages.

    Returns:
        True if a PDF was written to output_pdf. False if image_list is empty,
        if OCR failed for every page, or if the output could not be written.
        When only some pages fail, the remaining pages are still written and a
        warning is printed.
    """
    start_time = time.time()
    total_pages = len(image_list)
    if total_pages == 0:
        if verbose:
            print("No images found.")
        return False

    if threads is None or threads < 1:
        threads = multiprocessing.cpu_count()

    tasks = [
        (img, language, page_num, total_pages, verbose)
        for page_num, img in enumerate(image_list, start=1)
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        ocr_pdfs = list(executor.map(process_image_ocr, tasks))

    # Failed pages come back as None; never report success for an empty result,
    # since output_pdf may be the original file being replaced.
    good_pages = [page for page in ocr_pdfs if page]
    if not good_pages:
        print("OCR failed for every page; no PDF was written.")
        return False

    failed_count = total_pages - len(good_pages)
    if failed_count:
        print(f"Warning: {failed_count} of {total_pages} page(s) failed OCR "
              "and will be missing from the output.")

    if verbose:
        print("\nMerging OCRed pages...")

    # Write next to the destination and swap it in afterwards, so a failed
    # write cannot truncate an existing file at output_pdf.
    tmp_path = f"{output_pdf}.{os.getpid()}.tmp"
    merger = PdfMerger()
    try:
        for pdf_page in good_pages:
            merger.append(io.BytesIO(pdf_page))

        try:
            with open(tmp_path, "wb") as f:
                merger.write(f)
            os.replace(tmp_path, output_pdf)
        except Exception as e:
            print(f"Failed to write output PDF: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; the temp file may never have been created.
                pass
            return False
    finally:
        merger.close()

    elapsed_time = time.time() - start_time
    if verbose:
        print(f"OCR completed in {elapsed_time:.2f} seconds.")
        print(f"Searchable PDF created: {output_pdf}")

    return True


def _natural_sort_key(filename):
    """Return a sort key that compares digit runs in filename by numeric value."""
    import re

    parts = re.split(r'(\d+)', filename)
    # Splitting on a capturing group alternates text, digits, text, ..., so odd
    # positions are always digit runs and keys stay comparable element-wise.
    # The raw name is the tie-breaker for names like "01.jpg" and "1.jpg".
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], filename


def create_searchable_pdf_from_directory_of_images(
    input_dir, output_pdf, language='eng',
    threads=None, verbose=True
):
    """
    Collect all JPG/JPEG images in input_dir, sort them, and create a single searchable PDF.

    Files are ordered naturally, so "page2.jpg" comes before "page10.jpg";
    a plain string sort would put page 10 first and scramble numbered scans.

    Args:
        input_dir: Folder containing the .jpg/.jpeg page images.
        output_pdf: Destination path of the combined searchable PDF.
        language: OCR language code forwarded to the OCR step.
        threads: Number of OCR threads, forwarded unchanged.
        verbose: When True, print progress messages.

    Returns:
        The result of create_searchable_pdf_from_images for the images that
        could be opened.
    """
    if verbose:
        print(f"Processing images in folder '{input_dir}'...")

    image_names = sorted(
        (f for f in os.listdir(input_dir)
         if f.lower().endswith(('.jpg', '.jpeg'))),
        key=_natural_sort_key,
    )

    pil_images = []
    for name in image_names:
        img_path = os.path.join(input_dir, name)
        try:
            pil_images.append(Image.open(img_path))
        except Exception as e:
            # A skipped image means a missing page, so report it even when quiet.
            print(f"Skipping {img_path} due to error: {e}")

    if verbose:
        print(f"Found {len(pil_images)} images to process.")

    return create_searchable_pdf_from_images(
        pil_images, output_pdf, language, threads, verbose
    )


def flatten_pdf_to_images(input_pdf, dpi=300):
    """
    Render every page of input_pdf to an image via pdf2image.

    Returns whatever pdf2image's convert_from_path yields for the file at the
    requested dpi. Raises RuntimeError when pdf2image could not be imported.
    """
    renderer = convert_from_path
    if renderer is not None:
        return renderer(input_pdf, dpi=dpi)
    # The optional import at module load failed, so there is nothing to call.
    raise RuntimeError("pdf2image is not installed. Cannot flatten PDFs.")


def flatten_and_ocr_pdf(input_pdf, output_pdf, language='eng',
                        threads=None, verbose=True, dpi=300):
    """
    Flatten an existing PDF to images, then re-OCR into a new searchable PDF.

    Args:
        input_pdf: Path of the PDF to rasterise.
        output_pdf: Destination path of the searchable PDF (may equal input_pdf).
        language: OCR language code forwarded to the OCR step.
        threads: Number of OCR threads, forwarded unchanged.
        verbose: When True, print progress messages.
        dpi: Resolution used when rendering pages to images (default 300,
            the value that was previously hard-coded).

    Returns:
        False if the PDF could not be converted to images; otherwise the
        result of create_searchable_pdf_from_images.
    """
    if verbose:
        print(f"Flattening PDF '{input_pdf}' at {dpi} dpi...")

    try:
        pil_images = flatten_pdf_to_images(input_pdf, dpi=dpi)
    except Exception as e:
        # Conversion problems are reported and turned into a False result so
        # that batch callers can continue with the next file.
        print(f"Failed to convert PDF to images: {e}")
        return False

    if verbose:
        print(f"PDF has {len(pil_images)} pages. Starting OCR...")

    return create_searchable_pdf_from_images(
        pil_images, output_pdf,
        language=language,
        threads=threads,
        verbose=verbose
    )


def batch_flatten_and_ocr_pdfs(pdf_files, output_folder, language='eng',
                               threads=None, replace=False, verbose=True):
    """
    Batch-process a list of PDFs: flatten and re-OCR each.

    - If replace=True, overwrites each original PDF (output_folder is ignored).
    - Otherwise, writes "<name>_searchable.pdf" into output_folder; when
      output_folder is None the file is written next to its source PDF
      instead of failing inside os.path.join.

    Args:
        pdf_files: Paths of the PDFs to process.
        output_folder: Folder for the new PDFs, or None.
        language: OCR language code forwarded to the OCR step.
        threads: Number of OCR threads, forwarded unchanged.
        replace: Overwrite the originals instead of creating new files.
        verbose: When True, print progress messages.

    Returns:
        True only if every PDF was processed successfully; False if the list
        is empty or any file failed. Remaining files are still processed
        after a failure.
    """
    if not pdf_files:
        if verbose:
            print("No PDF files found in the folder.")
        return False

    if verbose:
        print(f"Found {len(pdf_files)} PDFs to process.")

    success = True
    for pdf_path in pdf_files:
        if replace:
            # Overwrite the original
            out_path = pdf_path
        else:
            base_name = os.path.splitext(os.path.basename(pdf_path))[0]
            target_dir = (
                output_folder if output_folder is not None
                else os.path.dirname(pdf_path)
            )
            out_path = os.path.join(target_dir, f"{base_name}_searchable.pdf")

        if not flatten_and_ocr_pdf(pdf_path, out_path, language, threads, verbose):
            success = False

    return success


def determine_input_mode(input_path, verbose=True):
    """
    Determine which of the three modes we're in:
      1) Single PDF file
      2) Folder of images
      3) Folder of PDFs

    Args:
        input_path: Path to a PDF file or to a folder.
        verbose: Kept for interface compatibility. Error messages are printed
            regardless, because quiet mode is documented to still show errors.

    Returns:
        A tuple (mode, items):
          ("single_pdf", <path>)        a PDF file, or a folder holding exactly one PDF
          ("folder_images", [<paths>])  a folder with JPG/JPEG files and no PDFs
          ("folder_pdfs", [<paths>])    a folder with several PDFs and no images
          (None, None)                  anything else; the reason is printed
        Path lists are sorted by file name. Only regular files are considered,
        so a sub-folder whose name ends in ".pdf" or ".jpg" is ignored.
    """
    if os.path.isfile(input_path):
        if input_path.lower().endswith('.pdf'):
            return ('single_pdf', input_path)
        # A lone image is not supported; only folders of images are.
        print("ERROR: Single file is not a PDF. Exiting.")
        return (None, None)

    if not os.path.isdir(input_path):
        print("ERROR: Input path is neither a file nor a folder.")
        return (None, None)

    pdf_files = []
    image_files = []
    for name in sorted(os.listdir(input_path)):
        full_path = os.path.join(input_path, name)
        if not os.path.isfile(full_path):
            continue
        lowered = name.lower()
        if lowered.endswith('.pdf'):
            pdf_files.append(full_path)
        elif lowered.endswith(('.jpg', '.jpeg')):
            image_files.append(full_path)

    if pdf_files and image_files:
        print("ERROR: The folder contains both images and PDFs. "
              "Please separate them or specify a single PDF file.")
        return (None, None)

    if pdf_files:
        # Edge case: exactly one PDF in the folder is treated as single_pdf.
        if len(pdf_files) == 1:
            return ('single_pdf', pdf_files[0])
        return ('folder_pdfs', pdf_files)

    if image_files:
        return ('folder_images', image_files)

    print("ERROR: The folder is empty or doesn't contain PDFs or JPGs.")
    return (None, None)


def _build_arg_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Create a searchable PDF from either: "
                    "a PDF file (flatten, re-OCR), "
                    "a folder of numbered image files, "
                    "or a folder of PDF files (batch)."
    )
    parser.add_argument('input_path',
                        help='Path to file/folder input (PDF file, folder of images, or folder of PDFs).')
    parser.add_argument('--output', '-o',
                        help='Output path. Interpretation depends on input: '
                             'single file/folder-of-images => single PDF file, '
                             'folder-of-pdfs => output folder for new PDFs. '
                             'Default: appends "_searchable" to new PDFs if not using --replace.')
    parser.add_argument('--replace', '-r', action='store_true',
                        help='Overwrite the original PDF(s). '
                             'Only valid if input is PDF(s).')
    parser.add_argument('--lang', '-l', default='eng',
                        help='OCR language (default: eng)')
    parser.add_argument('--threads', '-t', type=int,
                        help='Number of OCR threads (default: # of CPU cores).')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Minimize output messages.')
    return parser


def main():
    """
    Command-line entry point: parse arguments, detect the input mode and run
    the matching OCR workflow.

    Modes (as reported by determine_input_mode):
      - single_pdf:    re-OCR one PDF. Writes to --output if given, otherwise
                       overwrites the input PDF.
      - folder_images: combine the folder's images into one PDF. Writes to
                       --output if given, otherwise to
                       "<folder>/<folder>_searchable.pdf". --replace is ignored.
      - folder_pdfs:   re-OCR every PDF. With --replace the originals are
                       overwritten; otherwise results go to --output or to an
                       "OCRed_PDFs" subfolder of the input folder.

    Exits with status 1 when the input cannot be classified or processing
    fails, and with argparse's usage error for an invalid --threads value.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # ThreadPoolExecutor rejects max_workers < 1; report that as a usage error
    # up front instead of crashing in the middle of a run.
    if args.threads is not None and args.threads < 1:
        parser.error("--threads must be a positive integer.")

    verbose = not args.quiet
    input_path = os.path.normpath(args.input_path)
    mode, items = determine_input_mode(input_path, verbose=verbose)

    if mode is None:
        # determine_input_mode is responsible for reporting the reason.
        sys.exit(1)

    if mode == 'single_pdf':
        # items is the path to that single PDF; without --output the original
        # is replaced in place.
        pdf_path = items
        output_pdf = args.output if args.output else pdf_path

        success = flatten_and_ocr_pdf(
            pdf_path, output_pdf,
            language=args.lang,
            threads=args.threads,
            verbose=verbose
        )

    elif mode == 'folder_images':
        # There's no concept of replace for images.
        if args.replace and verbose:
            print("Warning: --replace has no effect for folder-of-images input.")

        if args.output:
            output_pdf = args.output
        else:
            # abspath so that an input of "." yields the real folder name
            # rather than the hidden file "._searchable.pdf".
            # NOTE(review): this default lands inside the input folder, so a
            # second run on the same folder will find a PDF next to the images.
            folder_name = os.path.basename(os.path.abspath(input_path))
            output_pdf = os.path.join(input_path, f"{folder_name}_searchable.pdf")

        success = create_searchable_pdf_from_directory_of_images(
            input_path, output_pdf, language=args.lang,
            threads=args.threads, verbose=verbose
        )

    elif mode == 'folder_pdfs':
        # items is the list of PDF files.
        pdf_files = items

        if args.replace:
            if args.output and verbose:
                print("Warning: --output is ignored when --replace is used "
                      "with a folder of PDFs.")
            output_folder = None
        else:
            if args.output:
                output_folder = os.path.normpath(args.output)
            else:
                # Default to an "OCRed_PDFs" subfolder inside the input folder.
                output_folder = os.path.join(input_path, "OCRed_PDFs")
                if verbose:
                    print(f"No output folder specified; using '{output_folder}'.")

            # exist_ok covers an existing folder; an existing *file* at this
            # path raises here and is reported instead of failing per PDF later.
            try:
                os.makedirs(output_folder, exist_ok=True)
            except OSError as e:
                print(f"ERROR: Could not create output folder: {e}")
                sys.exit(1)

        success = batch_flatten_and_ocr_pdfs(
            pdf_files, output_folder=output_folder,
            language=args.lang,
            threads=args.threads,
            replace=args.replace,
            verbose=verbose
        )

    else:
        # Shouldn't get here
        sys.exit(1)

    if not success:
        sys.exit(1)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()