sijapi/sijapi/utilities.py
2024-07-22 12:19:31 -07:00

481 lines
17 KiB
Python

# utilities.py
import re
import os
from fastapi import Form
import re
import io
from io import BytesIO
import base64
import math
import paramiko
from dateutil import parser
from pathlib import Path
import filetype
from PyPDF2 import PdfReader
from pdfminer.high_level import extract_text as pdfminer_extract_text
import pytesseract
from pdf2image import convert_from_path
from datetime import datetime, date, time
from typing import Optional, Union, Tuple
import asyncio
from PIL import Image
import pandas as pd
import ipaddress
from scipy.spatial import cKDTree
from dateutil.parser import parse as dateutil_parse
from docx import Document
from sshtunnel import SSHTunnelForwarder
from fastapi import Depends, HTTPException, Request, UploadFile
from fastapi.security.api_key import APIKeyHeader
from sijapi import L, API, YEAR_FMT, MONTH_FMT, DAY_FMT, DAY_SHORT_FMT, OBSIDIAN_VAULT_DIR, ALLOWED_FILENAME_CHARS, MAX_PATH_LENGTH, ARCHIVE_DIR
logger = L.get_module_logger('utilities')
def debug(text: str): logger.debug(text)
def info(text: str): logger.info(text)
def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text)
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def validate_api_key(request: Request, api_key: str = Depends(api_key_header)):
if request.url.path in API.PUBLIC:
return
client_ip = ipaddress.ip_address(request.client.host)
trusted_subnets = [ipaddress.ip_network(subnet) for subnet in API.TRUSTED_SUBNETS]
if any(client_ip in subnet for subnet in trusted_subnets):
return
# Check header-based API key
if api_key:
if api_key.lower().startswith("bearer "):
api_key = api_key.lower().split("bearer ")[-1]
if api_key in API.KEYS:
return
# Check query-based API key
api_key_query = request.query_params.get("api_key")
if api_key_query in API.KEYS:
return
raise HTTPException(status_code=401, detail="Invalid or missing API key")
def assemble_archive_path(filename: str, extension: str = ".md", date_time: datetime = datetime.now(), subdir: str = None) -> Tuple[Path, Path]:
year = date_time.strftime(YEAR_FMT)
month = date_time.strftime(MONTH_FMT)
day = date_time.strftime(DAY_FMT)
day_short = date_time.strftime(DAY_SHORT_FMT)
timestamp = date_time.strftime("%H%M%S")
# Ensure the extension is preserved
base_name, ext = os.path.splitext(filename)
extension = ext if ext else extension
# Initial sanitization
sanitized_base = sanitize_filename(base_name, '')
filename = f"{day_short} {timestamp} {sanitized_base}{extension}"
relative_path = Path(year) / month / day / filename
absolute_path = ARCHIVE_DIR / relative_path
# Ensure the total path length doesn't exceed MAX_PATH_LENGTH
while len(str(absolute_path)) > MAX_PATH_LENGTH:
# Truncate the sanitized_base, not the full filename
sanitized_base = sanitized_base[:-1]
filename = f"{day_short} {timestamp} {sanitized_base}{extension}"
relative_path = Path(year) / month / day / filename
absolute_path = ARCHIVE_DIR / relative_path
return absolute_path, relative_path
def assemble_journal_path(date_time: datetime, subdir: str = None, filename: str = None, extension: str = None, no_timestamp: bool = False) -> Tuple[Path, Path]:
'''
Obsidian helper. Takes a datetime and optional subdirectory name, filename, and extension.
If an extension is provided, it ensures the path is to a file with that extension.
If no extension is provided, it treats the path as a directory.
'''
year = date_time.strftime(YEAR_FMT)
month = date_time.strftime(MONTH_FMT)
day = date_time.strftime(DAY_FMT)
day_short = date_time.strftime(DAY_SHORT_FMT)
timestamp = date_time.strftime("%H%M%S")
relative_path = Path("journal") / year / month / day
if not subdir and not filename and not extension:
relative_path = relative_path / f"{day}.md"
else:
if subdir:
relative_path = relative_path / f"{day_short} {subdir}"
if filename:
if extension:
extension = extension if extension.startswith(".") else f".{extension}"
else:
extension = validate_extension(filename, [".md", ".m4a", ".wav", ".aiff", ".flac", ".mp3", ".mp4", ".pdf", ".js", ".json", ".yaml", ".py"]) or ".md"
filename = sanitize_filename(filename)
filename = f"{day_short} {filename}" if no_timestamp else f"{day_short} {timestamp} {filename}"
filename = f"{filename}{extension}" if not filename.endswith(extension) else filename
relative_path = relative_path / filename
else:
debug(f"This only happens, theoretically, when no filename nor subdirectory are provided, but an extension is. Which is kinda silly.")
return None, None
absolute_path = OBSIDIAN_VAULT_DIR / relative_path
os.makedirs(absolute_path.parent, exist_ok=True)
return absolute_path, relative_path
def validate_extension(filename, valid_extensions=None):
if valid_extensions is None:
return os.path.splitext(filename)
else:
extension = os.path.splitext(filename)[-1].lower()
return extension if extension in valid_extensions else None
def prefix_lines(text: str, prefix: str = '> ') -> str:
lines = text.split('\n')
prefixed_lines = [f"{prefix}{line.lstrip()}" for line in lines]
return '\n'.join(prefixed_lines)
def f(file):
if hasattr(file, 'read') and callable(file.read):
return file
if isinstance(file, (bytes, bytearray)):
return file
if isinstance(file, Path):
file_path = file
elif isinstance(file, str):
file_path = Path(file)
else:
raise TypeError("Invalid file type. Expected str, Path, or file-like object.")
with open(file_path, 'rb') as thefile:
return thefile
def get_extension(file):
try:
if isinstance(file, str):
file_path = Path(file)
elif isinstance(file, Path):
file_path = file
else:
file_path = Path(file.filename)
file_extension = file_path.suffix
return file_extension
except Exception as e:
err(f"Unable to get extension of {file}")
raise e
def sanitize_filename(text, extension: str = None, max_length: int = MAX_PATH_LENGTH):
"""Sanitize a string to be used as a safe filename while protecting the file extension."""
debug(f"Filename before sanitization: {text}")
text = re.sub(r'\s+', ' ', text)
sanitized = re.sub(ALLOWED_FILENAME_CHARS, '', text)
sanitized = sanitized.strip()
base_name, extension = os.path.splitext(sanitized)
max_base_length = max_length - len(extension)
if len(base_name) > max_base_length:
base_name = base_name[:max_base_length - 5].rstrip()
final_filename = base_name + extension
debug(f"Filename after sanitization: {final_filename}")
return final_filename
def check_file_name(file_name, max_length=255):
"""Check if the file name needs sanitization based on the criteria of the second sanitize_filename function."""
needs_sanitization = False
if len(file_name) > max_length:
debug(f"Filename exceeds maximum length of {max_length}: {file_name}")
needs_sanitization = True
if re.search(ALLOWED_FILENAME_CHARS, file_name):
debug(f"Filename contains non-word characters (except space, dot, and hyphen): {file_name}")
needs_sanitization = True
if re.search(r'\s{2,}', file_name):
debug(f"Filename contains multiple consecutive spaces: {file_name}")
needs_sanitization = True
if file_name != file_name.strip():
debug(f"Filename has leading or trailing spaces: {file_name}")
needs_sanitization = True
return needs_sanitization
def bool_convert(value: str = Form(None)):
return value.lower() in ["true", "1", "t", "y", "yes"]
def str_to_bool(value: str) -> bool:
"""
Convert a string to a boolean.
Interprets 'true', '1', 'yes', 'y' as True.
Interprets 'false', '0', 'no', 'n', '', or any other string as False.
"""
def get_timestamp():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
async def extract_text(file_path: str) -> str:
"""Extract text from file."""
if file_path.endswith('.pdf'):
return await extract_text_from_pdf(file_path)
elif file_path.endswith('.docx'):
return await extract_text_from_docx(file_path)
def clean_text(text):
text = text.replace('-', '')
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'[\u200B-\u200D\uFEFF]', '', text)
return text.strip()
async def ocr_pdf(file_path: str) -> str:
try:
images = await asyncio.to_thread(convert_from_path, file_path)
texts = await asyncio.gather(*(asyncio.to_thread(pytesseract.image_to_string, image) for image in images))
return ' '.join(texts)
except Exception as e:
err(f"Error during OCR: {str(e)}")
return ""
async def extract_text_from_pdf(file_path: str) -> str:
if not await is_valid_pdf(file_path):
err(f"Invalid PDF file: {file_path}")
return ""
text = ''
num_pages = 0
# First, attempt to extract text using PyPDF2
try:
reader = await asyncio.to_thread(PdfReader, file_path)
for page in reader.pages:
text_content = page.extract_text() + ' ' if page.extract_text() else ''
text += text_content
num_pages = len(reader.pages)
# If text was extracted successfully and it's deemed sufficient, return it
if text and not should_use_ocr(text, num_pages):
return clean_text(text)
except Exception as e:
err(f"Error extracting text with PyPDF2: {str(e)}")
# If PyPDF2 extraction fails or is insufficient, fall back to pdfminer.six
try:
text_pdfminer = await asyncio.to_thread(pdfminer_extract_text, file_path)
if text_pdfminer and not should_use_ocr(text_pdfminer, num_pages):
return clean_text(text_pdfminer)
except Exception as e:
err(f"Error extracting text with pdfminer.six: {e}")
# If both methods fail or are deemed insufficient, use OCR as the last resort
debug("Falling back to OCR for text extraction...")
return await ocr_pdf(file_path)
async def is_valid_pdf(file_path: str) -> bool:
"""Check if the file at file_path is a valid PDF."""
try:
kind = filetype.guess(file_path)
return kind.mime == 'application/pdf'
except Exception as e:
err(f"Error checking file type: {e}")
return False
async def extract_text_from_pdf(file_path: str) -> str:
if not await is_valid_pdf(file_path):
err(f"Invalid PDF file: {file_path}")
return ""
text = ''
try:
reader = await asyncio.to_thread(PdfReader, file_path)
for page in reader.pages:
text_content = page.extract_text() + ' ' if page.extract_text() else ''
text += text_content
if text.strip(): # Successfully extracted text
return clean_text(text)
except Exception as e:
err(f"Error extracting text with PyPDF2: {str(e)}")
try:
text_pdfminer = await asyncio.to_thread(pdfminer_extract_text, file_path)
if text_pdfminer.strip(): # Successfully extracted text
return clean_text(text_pdfminer)
except Exception as e:
err(f"Error extracting text with pdfminer.six: {str(e)}")
# Fall back to OCR
debug("Falling back to OCR for text extraction...")
try:
images = convert_from_path(file_path)
ocr_texts = await asyncio.gather(*(asyncio.to_thread(pytesseract.image_to_string, img) for img in images))
return ' '.join(ocr_texts).strip()
except Exception as e:
err(f"OCR failed: {str(e)}")
return ""
async def extract_text_from_docx(file_path: str) -> str:
def read_docx(file_path):
doc = Document(file_path)
full_text = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(full_text)
return await asyncio.to_thread(read_docx, file_path)
# Correcting read_text_file to be asynchronous
async def read_text_file(file_path: str) -> str:
# This opens and reads a file asynchronously by offloading to a separate thread
return await asyncio.to_thread(_sync_read_text_file, file_path)
def _sync_read_text_file(file_path: str) -> str:
# Actual synchronous file reading operation
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
def should_use_ocr(text, num_pages) -> bool:
if not text:
return True # No text was extracted, use OCR
word_count = len(text.split())
avg_words_per_page = word_count / num_pages
return avg_words_per_page < 10
def convert_to_unix_time(iso_date_str):
dt = parser.parse(iso_date_str) # Automatically parses datetime with timezone
return int(dt.timestamp())
def haversine(lat1, lon1, lat2, lon2):
""" Calculate the great circle distance between two points on the earth specified in decimal degrees. """
lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
c = 2 * math.asin(math.sqrt(a))
r = 6371 # Radius of Earth in kilometers
return c * r
def convert_degrees_to_cardinal(d):
"""
Convert degrees to cardinal directions
"""
dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE", "S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
ix = round(d / (360. / len(dirs)))
return dirs[ix % len(dirs)]
HOURLY_COLUMNS_MAPPING = {
"12am": "00:00:00",
"2am": "02:00:00",
"4am": "04:00:00",
"6am": "06:00:00",
"8am": "08:00:00",
"10am": "10:00:00",
"12pm": "12:00:00",
"2pm": "14:00:00",
"4pm": "16:00:00",
"6pm": "18:00:00",
"8pm": "20:00:00",
"10pm": "22:00:00",
}
def convert_to_12_hour_format(datetime_obj_or_str):
if isinstance(datetime_obj_or_str, str):
try:
datetime_obj = datetime.strptime(datetime_obj_or_str, "%Y-%m-%d %H:%M:%S")
except ValueError:
try:
datetime_obj = datetime.strptime(datetime_obj_or_str, "%H:%M:%S")
except ValueError:
return "Invalid datetime string format"
elif isinstance(datetime_obj_or_str, time):
datetime_obj_or_str = datetime_obj_or_str.strftime("%H:%M:%S")
else:
datetime_obj = datetime_obj_or_str
if isinstance(datetime_obj_or_str, str):
time24 = datetime_obj_or_str
else:
time24 = datetime_obj.strftime("%H:%M:%S")
reverse_mapping = {v: k for k, v in HOURLY_COLUMNS_MAPPING.items()}
return reverse_mapping.get(time24, "Invalid time")
def encode_image_to_base64(image_path):
if os.path.exists(image_path):
with Image.open(image_path) as image:
output_buffer = BytesIO()
image.save(output_buffer, format='JPEG')
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode('utf-8')
return base64_str
else:
debug(f"Error: File does not exist at {image_path}")
def resize_and_convert_image(image_path, max_size=2160, quality=80):
with Image.open(image_path) as img:
# Resize image
ratio = max_size / max(img.size)
new_size = tuple([int(x * ratio) for x in img.size])
img = img.resize(new_size, Image.Resampling.LANCZOS)
# Convert to jpg
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format='JPEG', quality=quality)
img_byte_arr = img_byte_arr.getvalue()
return img_byte_arr
def load_geonames_data(path: str):
columns = ['geonameid', 'name', 'asciiname', 'alternatenames',
'latitude', 'longitude', 'feature_class', 'feature_code',
'country_code', 'cc2', 'admin1_code', 'admin2_code', 'admin3_code',
'admin4_code', 'population', 'elevation', 'dem', 'timezone', 'modification_date']
data = pd.read_csv(
path,
sep='\t',
header=None,
names=columns,
low_memory=False
)
return data
async def run_ssh_command(server, command):
try:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(server.ssh.host, username=server.ssh.user, password=server.ssh.password)
stdin, stdout, stderr = ssh.exec_command(command)
output = stdout.read().decode()
error = stderr.read().decode()
ssh.close()
return output, error
except Exception as e:
err(f"SSH command failed for server {server.id}: {str(e)}")
raise