diff --git a/sijapi/__init__.py b/sijapi/__init__.py index 05c74a4..79121bb 100644 --- a/sijapi/__init__.py +++ b/sijapi/__init__.py @@ -7,7 +7,8 @@ import multiprocessing from dotenv import load_dotenv from dateutil import tz from pathlib import Path -from .classes import Logger, Configuration, APIConfig, DirConfig, Geocoder +from .classes import Logger, Configuration, APIConfig, Database, DirConfig, Geocoder + # INITIALization BASE_DIR = Path(__file__).resolve().parent @@ -19,17 +20,14 @@ os.makedirs(LOGS_DIR, exist_ok=True) L = Logger("Central", LOGS_DIR) # API essentials -API = APIConfig.load('api', 'secrets') +API = APIConfig.load('sys', 'secrets') Dir = DirConfig.load('dirs') +Db = Database.load('sys') -print(f"Data: {Dir.DATA}") -print(f"Config: {Dir.CONFIG}") -print(f"Logs: {Dir.LOGS}") -print(f"Podcast: {Dir.PODCAST}") +# HOST = f"{API.BIND}:{API.PORT}" +# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] +# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255') -HOST = f"{API.BIND}:{API.PORT}" -LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] -SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255') MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count()) IMG = Configuration.load('img', 'secrets', Dir) @@ -40,23 +38,6 @@ Scrape = Configuration.load('scrape', 'secrets', Dir) Serve = Configuration.load('serve', 'secrets', Dir) Tts = Configuration.load('tts', 'secrets', Dir) -# print(f"Tts configuration loaded: {Tts}") -# print(f"Tts.elevenlabs: {Tts.elevenlabs}") -# print(f"Tts.elevenlabs.key: {Tts.elevenlabs.key}") -# print(f"Tts.elevenlabs.voices: {Tts.elevenlabs.voices}") -# print(f"Configuration.resolve_placeholders method: {Configuration.resolve_placeholders}") -# print(f"Configuration.resolve_string_placeholders method: {Configuration.resolve_string_placeholders}") -# print(f"Secrets in Tts config: {[attr for attr in dir(Tts) if attr.isupper()]}") -# print(f"Type of Tts.elevenlabs: {type(Tts.elevenlabs)}") -# print(f"Attributes of Tts.elevenlabs: {dir(Tts.elevenlabs)}") -# print(f"ElevenLabs API key (masked): {'*' * len(Tts.elevenlabs.key) if hasattr(Tts.elevenlabs, 'key') else 'Not found'}") -# print(f"Type of Tts.elevenlabs.voices: {type(Tts.elevenlabs.voices)}") -# print(f"Attributes of Tts.elevenlabs.voices: {dir(Tts.elevenlabs.voices)}") -# print(f"Default voice: {Tts.elevenlabs.default if hasattr(Tts.elevenlabs, 'default') else 'Not found'}") -# print(f"Is 'get' method available on Tts.elevenlabs.voices? {'get' in dir(Tts.elevenlabs.voices)}") -# print(f"Is 'values' method available on Tts.elevenlabs.voices? {'values' in dir(Tts.elevenlabs.voices)}") -# print("Initialization complete") - # Directories & general paths ROUTER_DIR = BASE_DIR / "routers" DATA_DIR = BASE_DIR / "data" diff --git a/sijapi/classes.py b/sijapi/classes.py index 4c06752..cec786b 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -76,6 +76,7 @@ def info(text: str): logger.info(text) def warn(text: str): logger.warning(text) def err(text: str): logger.error(text) def crit(text: str): logger.critical(text) + BASE_DIR = Path(__file__).resolve().parent CONFIG_DIR = BASE_DIR / "config" ENV_PATH = CONFIG_DIR / ".env" @@ -226,100 +227,297 @@ class Configuration(BaseModel): arbitrary_types_allowed = True -class DirConfig(BaseModel): - HOME: Path = Path.home() - +class DirConfig: + def __init__(self, config_data: dict): + self.BASE = Path(__file__).parent.parent + self.HOME = Path.home() + self.DATA = self.BASE / "data" + + for key, value in config_data.items(): + setattr(self, key, self._resolve_path(value)) + + def _resolve_path(self, path: str) -> Path: + path = path.replace("{{ BASE }}", str(self.BASE)) + path = path.replace("{{ HOME }}", str(self.HOME)) + path = path.replace("{{ DATA }}", str(self.DATA)) + return Path(path).expanduser() + @classmethod def load(cls, yaml_path: Union[str, Path]) -> 'DirConfig': - yaml_path = cls._resolve_path(yaml_path, 'config') + yaml_path = Path(yaml_path) + if not yaml_path.is_absolute(): + yaml_path = Path(__file__).parent / "config" / yaml_path + + if not yaml_path.suffix: + yaml_path = yaml_path.with_suffix('.yaml') - try: - with yaml_path.open('r') as file: - config_data = yaml.safe_load(file) + with yaml_path.open('r') as file: + config_data = yaml.safe_load(file) - print(f"Loaded configuration data from {yaml_path}") - - # Ensure HOME is set - if 'HOME' not in config_data: - config_data['HOME'] = str(Path.home()) - print(f"HOME was not in config, set to default: {config_data['HOME']}") - - # Create a temporary instance to resolve placeholders - temp_instance = cls.create_dynamic_model(**config_data) - resolved_data = temp_instance.resolve_placeholders(config_data) - - # Create the final instance with resolved data - return cls.create_dynamic_model(**resolved_data) - - except Exception as e: - print(f"Error loading configuration: {str(e)}") - raise + return cls(config_data) - @classmethod - def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path: + +class Database: + def __init__(self, config_path: str): + self.config = self.load_config(config_path) + self.pool_connections = {} + self.local_ts_id = self.get_local_ts_id() + + def load_config(self, config_path: str) -> Dict[str, Any]: base_path = Path(__file__).parent.parent - path = Path(path) - if not path.suffix: - path = base_path / 'sijapi' / default_dir / f"{path.name}.yaml" - elif not path.is_absolute(): - path = base_path / path - return path + full_path = base_path / "sijapi" / "config" / f"{config_path}.yaml" + + with open(full_path, 'r') as file: + config = yaml.safe_load(file) + + return config - def resolve_placeholders(self, data: Any) -> Any: - if isinstance(data, dict): - resolved_data = {k: self.resolve_placeholders(v) for k, v in data.items()} - home_dir = Path(resolved_data.get('HOME', self.HOME)).expanduser() - base_dir = Path(__file__).parent.parent - data_dir = base_dir / "data" - resolved_data['HOME'] = str(home_dir) - resolved_data['BASE'] = str(base_dir) - resolved_data['DATA'] = str(data_dir) - return resolved_data - elif isinstance(data, list): - return [self.resolve_placeholders(v) for v in data] - elif isinstance(data, str): - return self.resolve_string_placeholders(data) + def get_local_ts_id(self) -> str: + return os.environ.get('TS_ID') + + async def get_connection(self, ts_id: str = None): + if ts_id is None: + ts_id = self.local_ts_id + + if ts_id not in self.pool_connections: + db_info = next((db for db in self.config['POOL'] if db['ts_id'] == ts_id), None) + if db_info is None: + raise ValueError(f"No database configuration found for TS_ID: {ts_id}") + + self.pool_connections[ts_id] = await asyncpg.create_pool( + host=db_info['ts_ip'], + port=db_info['db_port'], + user=db_info['db_user'], + password=db_info['db_pass'], + database=db_info['db_name'], + min_size=1, + max_size=10 + ) + + return await self.pool_connections[ts_id].acquire() + + async def release_connection(self, ts_id: str, connection): + await self.pool_connections[ts_id].release(connection) + + async def get_online_servers(self) -> List[str]: + online_servers = [] + for db_info in self.config['POOL']: + try: + conn = await self.get_connection(db_info['ts_id']) + await self.release_connection(db_info['ts_id'], conn) + online_servers.append(db_info['ts_id']) + except: + pass + return online_servers + + async def initialize_query_tracking(self): + conn = await self.get_connection() + try: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS query_tracking ( + id SERIAL PRIMARY KEY, + ts_id TEXT NOT NULL, + query TEXT NOT NULL, + args JSONB, + executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + completed_by JSONB DEFAULT '{}'::jsonb, + result_checksum TEXT + ) + """) + finally: + await self.release_connection(self.local_ts_id, conn) + + async def execute_read(self, query: str, *args): + conn = await self.get_connection() + try: + return await conn.fetch(query, *args) + finally: + await self.release_connection(self.local_ts_id, conn) + + async def execute_write(self, query: str, *args): + # Execute write on local database + local_conn = await self.get_connection() + try: + await local_conn.execute(query, *args) + + # Log the query + query_id = await local_conn.fetchval(""" + INSERT INTO query_tracking (ts_id, query, args) + VALUES ($1, $2, $3) + RETURNING id + """, self.local_ts_id, query, json.dumps(args)) + finally: + await self.release_connection(self.local_ts_id, local_conn) + + # Calculate checksum + checksum = await self.compute_checksum(query, *args) + + # Update query_tracking with checksum + await self.update_query_checksum(query_id, checksum) + + # Replicate to online servers + online_servers = await self.get_online_servers() + for ts_id in online_servers: + if ts_id != self.local_ts_id: + asyncio.create_task(self._replicate_write(ts_id, query_id, query, args, checksum)) + + async def get_primary_server(self) -> str: + url = urljoin(self.config['URL'], '/id') + + async with aiohttp.ClientSession() as session: + try: + async with session.get(url) as response: + if response.status == 200: + primary_ts_id = await response.text() + return primary_ts_id.strip() + else: + logging.error(f"Failed to get primary server. Status: {response.status}") + return None + except aiohttp.ClientError as e: + logging.error(f"Error connecting to load balancer: {str(e)}") + return None + + async def get_checksum_server(self) -> dict: + primary_ts_id = await self.get_primary_server() + online_servers = await self.get_online_servers() + + checksum_servers = [server for server in self.config['POOL'] if server['ts_id'] in online_servers and server['ts_id'] != primary_ts_id] + + if not checksum_servers: + return next(server for server in self.config['POOL'] if server['ts_id'] == primary_ts_id) + + return random.choice(checksum_servers) + + async def compute_checksum(self, query: str, *args): + checksum_server = await self.get_checksum_server() + + if checksum_server['ts_id'] == self.local_ts_id: + return await self._local_compute_checksum(query, *args) else: - return data + return await self._delegate_compute_checksum(checksum_server, query, *args) - def resolve_string_placeholders(self, value: str) -> Path: - pattern = r'\{\{\s*([^}]+)\s*\}\}' - matches = re.findall(pattern, value) + async def _local_compute_checksum(self, query: str, *args): + conn = await self.get_connection() + try: + result = await conn.fetch(query, *args) + checksum = hashlib.md5(str(result).encode()).hexdigest() + return checksum + finally: + await self.release_connection(self.local_ts_id, conn) + + async def _delegate_compute_checksum(self, server: dict, query: str, *args): + url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum" - for match in matches: - if match == 'HOME': - replacement = str(self.HOME) - elif match == 'BASE': - replacement = str(Path(__file__).parent.parent) - elif match == 'DATA': - replacement = str(Path(__file__).parent.parent / "data") - elif hasattr(self, match): - replacement = str(getattr(self, match)) - else: - replacement = value + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, json={"query": query, "args": list(args)}) as response: + if response.status == 200: + result = await response.json() + return result['checksum'] + else: + logging.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}") + return await self._local_compute_checksum(query, *args) + except aiohttp.ClientError as e: + logging.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}") + return await self._local_compute_checksum(query, *args) + + async def update_query_checksum(self, query_id: int, checksum: str): + conn = await self.get_connection() + try: + await conn.execute(""" + UPDATE query_tracking + SET result_checksum = $1 + WHERE id = $2 + """, checksum, query_id) + finally: + await self.release_connection(self.local_ts_id, conn) + + async def _replicate_write(self, ts_id: str, query_id: int, query: str, args: tuple, expected_checksum: str): + try: + conn = await self.get_connection(ts_id) + try: + await conn.execute(query, *args) + actual_checksum = await self.compute_checksum(query, *args) + if actual_checksum != expected_checksum: + raise ValueError(f"Checksum mismatch on {ts_id}") + await self.mark_query_completed(query_id, ts_id) + finally: + await self.release_connection(ts_id, conn) + except Exception as e: + logging.error(f"Failed to replicate write on {ts_id}: {str(e)}") + + async def mark_query_completed(self, query_id: int, ts_id: str): + conn = await self.get_connection() + try: + await conn.execute(""" + UPDATE query_tracking + SET completed_by = completed_by || jsonb_build_object($1, true) + WHERE id = $2 + """, ts_id, query_id) + finally: + await self.release_connection(self.local_ts_id, conn) + + async def sync_local_server(self): + conn = await self.get_connection() + try: + last_synced_id = await conn.fetchval(""" + SELECT COALESCE(MAX(id), 0) FROM query_tracking + WHERE completed_by ? $1 + """, self.local_ts_id) + + unexecuted_queries = await conn.fetch(""" + SELECT id, query, args, result_checksum + FROM query_tracking + WHERE id > $1 + ORDER BY id + """, last_synced_id) + + for query in unexecuted_queries: + try: + await conn.execute(query['query'], *json.loads(query['args'])) + actual_checksum = await self.compute_checksum(query['query'], *json.loads(query['args'])) + if actual_checksum != query['result_checksum']: + raise ValueError(f"Checksum mismatch for query ID {query['id']}") + await self.mark_query_completed(query['id'], self.local_ts_id) + except Exception as e: + logging.error(f"Failed to execute query ID {query['id']} during local sync: {str(e)}") + + logging.info(f"Local server sync completed. Executed {len(unexecuted_queries)} queries.") + + finally: + await self.release_connection(self.local_ts_id, conn) + + async def purge_completed_queries(self): + conn = await self.get_connection() + try: + all_ts_ids = [db['ts_id'] for db in self.config['POOL']] + result = await conn.execute(""" + WITH consecutive_completed AS ( + SELECT id, + row_number() OVER (ORDER BY id) AS rn + FROM query_tracking + WHERE completed_by ?& $1 + ) + DELETE FROM query_tracking + WHERE id IN ( + SELECT id + FROM consecutive_completed + WHERE rn = (SELECT MAX(rn) FROM consecutive_completed) + ) + """, all_ts_ids) + deleted_count = int(result.split()[-1]) + logging.info(f"Purged {deleted_count} completed queries.") + finally: + await self.release_connection(self.local_ts_id, conn) + + async def close(self): + for pool in self.pool_connections.values(): + await pool.close() - value = value.replace('{{' + match + '}}', replacement) - - return Path(value).expanduser() - @classmethod - def create_dynamic_model(cls, **data): - DynamicModel = create_model( - f'Dynamic{cls.__name__}', - __base__=cls, - **{k: (Path, v) for k, v in data.items()} - ) - return DynamicModel(**data) - - class Config: - arbitrary_types_allowed = True - - - - -# Configuration class for API & Database methods. +# Configuration class for API & Database methods. class APIConfig(BaseModel): HOST: str PORT: int diff --git a/sijapi/config/.env-example b/sijapi/config/.env-example index 51c3a18..b40a9da 100644 --- a/sijapi/config/.env-example +++ b/sijapi/config/.env-example @@ -94,7 +94,7 @@ TRUSTED_SUBNETS=127.0.0.1/32,10.13.37.0/24,100.64.64.0/24 # ────────── # #─── router selection: ──────────────────────────────────────────────────────────── -ROUTERS=asr,cal,cf,email,health,llm,loc,note,rag,img,serve,time,tts,weather +ROUTERS=asr,cal,cf,email,llm,loc,note,rag,img,serve,sys,time,tts,weather UNLOADED=ig #─── notes: ────────────────────────────────────────────────────────────────────── # diff --git a/sijapi/config/api.yaml-example b/sijapi/config/sys.yaml-example similarity index 99% rename from sijapi/config/api.yaml-example rename to sijapi/config/sys.yaml-example index ff5c0b2..0f23a68 100644 --- a/sijapi/config/api.yaml-example +++ b/sijapi/config/sys.yaml-example @@ -27,7 +27,6 @@ MODULES: dist: off email: on gis: on - health: on ig: off img: on llm: on @@ -36,6 +35,7 @@ MODULES: rag: off scrape: on serve: on + sys: on timing: on tts: on weather: on diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py index 9940fc6..170b85e 100644 --- a/sijapi/routers/gis.py +++ b/sijapi/routers/gis.py @@ -270,39 +270,35 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli else: end_date = start_date.replace(hour=23, minute=59, second=59) - # Fetch locations locations = await fetch_locations(start_date, end_date) if not locations: raise ValueError("No locations found for the given date range") - # Create map m = folium.Map() - - # Prepare heatmap data heat_data = [[loc.latitude, loc.longitude] for loc in locations] - - # Add heatmap layer HeatMap(heat_data).add_to(m) - # Fit the map to the bounds of all locations bounds = [ [min(loc.latitude for loc in locations), min(loc.longitude for loc in locations)], [max(loc.latitude for loc in locations), max(loc.longitude for loc in locations)] ] m.fit_bounds(bounds) - - # Generate output path if not provided - if output_path is None: - output_path, relative_path = assemble_journal_path(end_date, filename="map", extension=".png", no_timestamp=True) - - # Save the map as PNG - m.save(str(output_path)) - - info(f"Heatmap saved as PNG: {output_path}") - return output_path + + try: + if output_path is None: + output_path, relative_path = assemble_journal_path(end_date, filename="map", extension=".png", no_timestamp=True) + + m.save(str(output_path)) + + info(f"Heatmap saved as PNG: {output_path}") + return output_path + + except Exception as e: + err(f"Error saving heatmap: {str(e)}") + raise except Exception as e: - err(f"Error generating and saving heatmap: {str(e)}") + err(f"Error generating heatmap: {str(e)}") raise async def generate_map(start_date: datetime, end_date: datetime, max_points: int): diff --git a/sijapi/routers/note.py b/sijapi/routers/note.py index 5f8dae0..476563a 100644 --- a/sijapi/routers/note.py +++ b/sijapi/routers/note.py @@ -337,8 +337,8 @@ Obsidian helper. Takes a datetime and creates a new daily note. Note: it uses th _, note_path = assemble_journal_path(date_time, filename="Notes", extension=".md", no_timestamp = True) note_embed = f"![[{note_path}]]" - _, map_path = assemble_journal_path(date_time, filename="Map", extension=".png", no_timestamp = True) - map = await gis.generate_and_save_heatmap(date_time, output_path=map_path) + absolute_map_path, map_path = assemble_journal_path(date_time, filename="Map", extension=".png", no_timestamp = True) + map = await gis.generate_and_save_heatmap(date_time, output_path=absolute_map_path) map_embed = f"![[{map_path}]]" _, banner_path = assemble_journal_path(date_time, filename="Banner", extension=".jpg", no_timestamp = True) diff --git a/sijapi/routers/health.py b/sijapi/routers/sys.py similarity index 79% rename from sijapi/routers/health.py rename to sijapi/routers/sys.py index 67773c4..5a81fb5 100644 --- a/sijapi/routers/health.py +++ b/sijapi/routers/sys.py @@ -1,9 +1,7 @@ ''' -Health check module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP. -Depends on: - TS_ID, LOGGER, SUBNET_BROADCAST +System module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP. ''' -#routers/health.py +#routers/sys.py import os import httpx @@ -12,7 +10,7 @@ from fastapi import APIRouter from tailscale import Tailscale from sijapi import L, API, TS_ID, SUBNET_BROADCAST -health = APIRouter(tags=["public", "trusted", "private"]) +sys = APIRouter(tags=["public", "trusted", "private"]) logger = L.get_module_logger("health") def debug(text: str): logger.debug(text) def info(text: str): logger.info(text) @@ -20,20 +18,20 @@ def warn(text: str): logger.warning(text) def err(text: str): logger.error(text) def crit(text: str): logger.critical(text) -@health.get("/health") +@sys.get("/health") def get_health(): return {"status": "ok"} -@health.get("/id") +@sys.get("/id") def get_health() -> str: return TS_ID -@health.get("/routers") +@sys.get("/routers") def get_routers() -> str: active_modules = [module for module, is_active in API.MODULES.__dict__.items() if is_active] return active_modules -@health.get("/ip") +@sys.get("/ip") def get_local_ip(): """Get the server's local IP address.""" s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -46,7 +44,7 @@ def get_local_ip(): s.close() return IP -@health.get("/wan_ip") +@sys.get("/wan_ip") async def get_wan_ip(): """Get the WAN IP address using Mullvad's API.""" async with httpx.AsyncClient() as client: @@ -59,7 +57,7 @@ async def get_wan_ip(): err(f"Error fetching WAN IP: {e}") return "Unavailable" -@health.get("/ts_ip") +@sys.get("/ts_ip") async def get_tailscale_ip(): """Get the Tailscale IP address.""" tailnet = os.getenv("TAILNET") diff --git a/sijapi/routers/tts.py b/sijapi/routers/tts.py index 1c718c6..4118b8e 100644 --- a/sijapi/routers/tts.py +++ b/sijapi/routers/tts.py @@ -203,7 +203,7 @@ async def generate_speech( speed: float = 1.1, podcast: bool = False, title: str = None, - output_dir = None + output_dir = None, ) -> str: debug(f"Entering generate_speech function") debug(f"API.EXTENSIONS: {API.EXTENSIONS}") @@ -213,14 +213,14 @@ async def generate_speech( debug(f"Type of Tts: {type(Tts)}") debug(f"Dir of Tts: {dir(Tts)}") - output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR - if not output_dir.exists(): - output_dir.mkdir(parents=True) + + use_output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR + if not use_output_dir.exists(): use_output_dir.mkdir(parents=True) try: model = model if model else await get_model(voice, voice_file) title = title if title else "TTS audio" - output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav" + output_path = use_output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav" debug(f"Model: {model}") debug(f"Voice: {voice}") @@ -228,7 +228,7 @@ async def generate_speech( if model == "eleven_turbo_v2" and getattr(API.EXTENSIONS, 'elevenlabs', False): info("Using ElevenLabs.") - audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir) + audio_file_path = await elevenlabs_tts(model, text, voice, title, use_output_dir) elif getattr(API.EXTENSIONS, 'xtts', False): info("Using XTTS2") audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path) @@ -244,24 +244,29 @@ async def generate_speech( warn(f"No file exists at {audio_file_path}") if podcast: - podcast_path = Dir.PODCAST / audio_file_path + podcast_path = Dir.PODCAST / audio_file_path.name - shutil.copy(audio_file_path, podcast_path) - if podcast_path.exists(): - info(f"Saved to podcast path: {podcast_path}") - else: - warn(f"Podcast mode enabled, but failed to save to {podcast_path}") - if podcast_path != audio_file_path: - info(f"Podcast mode enabled, so we will remove {audio_file_path}") - bg_tasks.add_task(os.remove, audio_file_path) + shutil.copy(audio_file_path, podcast_path) + if podcast_path.exists(): + info(f"Saved to podcast path: {podcast_path}") + else: + warn(f"Podcast mode enabled, but failed to save to {podcast_path}") + + if output_dir and Path(output_dir) == use_output_dir: + debug(f"Keeping {audio_file_path} because it was specified") + + else: + info(f"Podcast mode enabled and output_dir not specified so we will remove {audio_file_path}") + bg_tasks.add_task(os.remove, audio_file_path) else: - warn(f"Podcast path set to same as audio file path...") - + warn(f"Podcast path is the same as audio file path. Using existing file.") + return podcast_path - + return audio_file_path + except Exception as e: err(f"Failed to generate speech: {e}") err(f"Traceback: {traceback.format_exc()}")