First steps of significant database sync redesign

This commit is contained in:
sanj 2024-08-08 23:34:13 -07:00
parent c9c36b4c42
commit 46c5db23de
8 changed files with 335 additions and 157 deletions

View file

@ -7,7 +7,8 @@ import multiprocessing
from dotenv import load_dotenv from dotenv import load_dotenv
from dateutil import tz from dateutil import tz
from pathlib import Path from pathlib import Path
from .classes import Logger, Configuration, APIConfig, DirConfig, Geocoder from .classes import Logger, Configuration, APIConfig, Database, DirConfig, Geocoder
# INITIALization # INITIALization
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
@ -19,17 +20,14 @@ os.makedirs(LOGS_DIR, exist_ok=True)
L = Logger("Central", LOGS_DIR) L = Logger("Central", LOGS_DIR)
# API essentials # API essentials
API = APIConfig.load('api', 'secrets') API = APIConfig.load('sys', 'secrets')
Dir = DirConfig.load('dirs') Dir = DirConfig.load('dirs')
Db = Database.load('sys')
print(f"Data: {Dir.DATA}") # HOST = f"{API.BIND}:{API.PORT}"
print(f"Config: {Dir.CONFIG}") # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
print(f"Logs: {Dir.LOGS}") # SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
print(f"Podcast: {Dir.PODCAST}")
HOST = f"{API.BIND}:{API.PORT}"
LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count()) MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
IMG = Configuration.load('img', 'secrets', Dir) IMG = Configuration.load('img', 'secrets', Dir)
@ -40,23 +38,6 @@ Scrape = Configuration.load('scrape', 'secrets', Dir)
Serve = Configuration.load('serve', 'secrets', Dir) Serve = Configuration.load('serve', 'secrets', Dir)
Tts = Configuration.load('tts', 'secrets', Dir) Tts = Configuration.load('tts', 'secrets', Dir)
# print(f"Tts configuration loaded: {Tts}")
# print(f"Tts.elevenlabs: {Tts.elevenlabs}")
# print(f"Tts.elevenlabs.key: {Tts.elevenlabs.key}")
# print(f"Tts.elevenlabs.voices: {Tts.elevenlabs.voices}")
# print(f"Configuration.resolve_placeholders method: {Configuration.resolve_placeholders}")
# print(f"Configuration.resolve_string_placeholders method: {Configuration.resolve_string_placeholders}")
# print(f"Secrets in Tts config: {[attr for attr in dir(Tts) if attr.isupper()]}")
# print(f"Type of Tts.elevenlabs: {type(Tts.elevenlabs)}")
# print(f"Attributes of Tts.elevenlabs: {dir(Tts.elevenlabs)}")
# print(f"ElevenLabs API key (masked): {'*' * len(Tts.elevenlabs.key) if hasattr(Tts.elevenlabs, 'key') else 'Not found'}")
# print(f"Type of Tts.elevenlabs.voices: {type(Tts.elevenlabs.voices)}")
# print(f"Attributes of Tts.elevenlabs.voices: {dir(Tts.elevenlabs.voices)}")
# print(f"Default voice: {Tts.elevenlabs.default if hasattr(Tts.elevenlabs, 'default') else 'Not found'}")
# print(f"Is 'get' method available on Tts.elevenlabs.voices? {'get' in dir(Tts.elevenlabs.voices)}")
# print(f"Is 'values' method available on Tts.elevenlabs.voices? {'values' in dir(Tts.elevenlabs.voices)}")
# print("Initialization complete")
# Directories & general paths # Directories & general paths
ROUTER_DIR = BASE_DIR / "routers" ROUTER_DIR = BASE_DIR / "routers"
DATA_DIR = BASE_DIR / "data" DATA_DIR = BASE_DIR / "data"

View file

@ -76,6 +76,7 @@ def info(text: str): logger.info(text)
def warn(text: str): logger.warning(text) def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
BASE_DIR = Path(__file__).resolve().parent BASE_DIR = Path(__file__).resolve().parent
CONFIG_DIR = BASE_DIR / "config" CONFIG_DIR = BASE_DIR / "config"
ENV_PATH = CONFIG_DIR / ".env" ENV_PATH = CONFIG_DIR / ".env"
@ -226,100 +227,297 @@ class Configuration(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
class DirConfig(BaseModel): class DirConfig:
HOME: Path = Path.home() def __init__(self, config_data: dict):
self.BASE = Path(__file__).parent.parent
self.HOME = Path.home()
self.DATA = self.BASE / "data"
for key, value in config_data.items():
setattr(self, key, self._resolve_path(value))
def _resolve_path(self, path: str) -> Path:
path = path.replace("{{ BASE }}", str(self.BASE))
path = path.replace("{{ HOME }}", str(self.HOME))
path = path.replace("{{ DATA }}", str(self.DATA))
return Path(path).expanduser()
@classmethod @classmethod
def load(cls, yaml_path: Union[str, Path]) -> 'DirConfig': def load(cls, yaml_path: Union[str, Path]) -> 'DirConfig':
yaml_path = cls._resolve_path(yaml_path, 'config') yaml_path = Path(yaml_path)
if not yaml_path.is_absolute():
yaml_path = Path(__file__).parent / "config" / yaml_path
if not yaml_path.suffix:
yaml_path = yaml_path.with_suffix('.yaml')
try: with yaml_path.open('r') as file:
with yaml_path.open('r') as file: config_data = yaml.safe_load(file)
config_data = yaml.safe_load(file)
print(f"Loaded configuration data from {yaml_path}") return cls(config_data)
# Ensure HOME is set
if 'HOME' not in config_data:
config_data['HOME'] = str(Path.home())
print(f"HOME was not in config, set to default: {config_data['HOME']}")
# Create a temporary instance to resolve placeholders
temp_instance = cls.create_dynamic_model(**config_data)
resolved_data = temp_instance.resolve_placeholders(config_data)
# Create the final instance with resolved data
return cls.create_dynamic_model(**resolved_data)
except Exception as e:
print(f"Error loading configuration: {str(e)}")
raise
@classmethod
def _resolve_path(cls, path: Union[str, Path], default_dir: str) -> Path: class Database:
def __init__(self, config_path: str):
self.config = self.load_config(config_path)
self.pool_connections = {}
self.local_ts_id = self.get_local_ts_id()
def load_config(self, config_path: str) -> Dict[str, Any]:
base_path = Path(__file__).parent.parent base_path = Path(__file__).parent.parent
path = Path(path) full_path = base_path / "sijapi" / "config" / f"{config_path}.yaml"
if not path.suffix:
path = base_path / 'sijapi' / default_dir / f"{path.name}.yaml" with open(full_path, 'r') as file:
elif not path.is_absolute(): config = yaml.safe_load(file)
path = base_path / path
return path return config
def resolve_placeholders(self, data: Any) -> Any: def get_local_ts_id(self) -> str:
if isinstance(data, dict): return os.environ.get('TS_ID')
resolved_data = {k: self.resolve_placeholders(v) for k, v in data.items()}
home_dir = Path(resolved_data.get('HOME', self.HOME)).expanduser() async def get_connection(self, ts_id: str = None):
base_dir = Path(__file__).parent.parent if ts_id is None:
data_dir = base_dir / "data" ts_id = self.local_ts_id
resolved_data['HOME'] = str(home_dir)
resolved_data['BASE'] = str(base_dir) if ts_id not in self.pool_connections:
resolved_data['DATA'] = str(data_dir) db_info = next((db for db in self.config['POOL'] if db['ts_id'] == ts_id), None)
return resolved_data if db_info is None:
elif isinstance(data, list): raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
return [self.resolve_placeholders(v) for v in data]
elif isinstance(data, str): self.pool_connections[ts_id] = await asyncpg.create_pool(
return self.resolve_string_placeholders(data) host=db_info['ts_ip'],
port=db_info['db_port'],
user=db_info['db_user'],
password=db_info['db_pass'],
database=db_info['db_name'],
min_size=1,
max_size=10
)
return await self.pool_connections[ts_id].acquire()
async def release_connection(self, ts_id: str, connection):
await self.pool_connections[ts_id].release(connection)
async def get_online_servers(self) -> List[str]:
online_servers = []
for db_info in self.config['POOL']:
try:
conn = await self.get_connection(db_info['ts_id'])
await self.release_connection(db_info['ts_id'], conn)
online_servers.append(db_info['ts_id'])
except:
pass
return online_servers
async def initialize_query_tracking(self):
conn = await self.get_connection()
try:
await conn.execute("""
CREATE TABLE IF NOT EXISTS query_tracking (
id SERIAL PRIMARY KEY,
ts_id TEXT NOT NULL,
query TEXT NOT NULL,
args JSONB,
executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
completed_by JSONB DEFAULT '{}'::jsonb,
result_checksum TEXT
)
""")
finally:
await self.release_connection(self.local_ts_id, conn)
async def execute_read(self, query: str, *args):
conn = await self.get_connection()
try:
return await conn.fetch(query, *args)
finally:
await self.release_connection(self.local_ts_id, conn)
async def execute_write(self, query: str, *args):
# Execute write on local database
local_conn = await self.get_connection()
try:
await local_conn.execute(query, *args)
# Log the query
query_id = await local_conn.fetchval("""
INSERT INTO query_tracking (ts_id, query, args)
VALUES ($1, $2, $3)
RETURNING id
""", self.local_ts_id, query, json.dumps(args))
finally:
await self.release_connection(self.local_ts_id, local_conn)
# Calculate checksum
checksum = await self.compute_checksum(query, *args)
# Update query_tracking with checksum
await self.update_query_checksum(query_id, checksum)
# Replicate to online servers
online_servers = await self.get_online_servers()
for ts_id in online_servers:
if ts_id != self.local_ts_id:
asyncio.create_task(self._replicate_write(ts_id, query_id, query, args, checksum))
async def get_primary_server(self) -> str:
url = urljoin(self.config['URL'], '/id')
async with aiohttp.ClientSession() as session:
try:
async with session.get(url) as response:
if response.status == 200:
primary_ts_id = await response.text()
return primary_ts_id.strip()
else:
logging.error(f"Failed to get primary server. Status: {response.status}")
return None
except aiohttp.ClientError as e:
logging.error(f"Error connecting to load balancer: {str(e)}")
return None
async def get_checksum_server(self) -> dict:
primary_ts_id = await self.get_primary_server()
online_servers = await self.get_online_servers()
checksum_servers = [server for server in self.config['POOL'] if server['ts_id'] in online_servers and server['ts_id'] != primary_ts_id]
if not checksum_servers:
return next(server for server in self.config['POOL'] if server['ts_id'] == primary_ts_id)
return random.choice(checksum_servers)
async def compute_checksum(self, query: str, *args):
checksum_server = await self.get_checksum_server()
if checksum_server['ts_id'] == self.local_ts_id:
return await self._local_compute_checksum(query, *args)
else: else:
return data return await self._delegate_compute_checksum(checksum_server, query, *args)
def resolve_string_placeholders(self, value: str) -> Path: async def _local_compute_checksum(self, query: str, *args):
pattern = r'\{\{\s*([^}]+)\s*\}\}' conn = await self.get_connection()
matches = re.findall(pattern, value) try:
result = await conn.fetch(query, *args)
checksum = hashlib.md5(str(result).encode()).hexdigest()
return checksum
finally:
await self.release_connection(self.local_ts_id, conn)
async def _delegate_compute_checksum(self, server: dict, query: str, *args):
url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
for match in matches: async with aiohttp.ClientSession() as session:
if match == 'HOME': try:
replacement = str(self.HOME) async with session.post(url, json={"query": query, "args": list(args)}) as response:
elif match == 'BASE': if response.status == 200:
replacement = str(Path(__file__).parent.parent) result = await response.json()
elif match == 'DATA': return result['checksum']
replacement = str(Path(__file__).parent.parent / "data") else:
elif hasattr(self, match): logging.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
replacement = str(getattr(self, match)) return await self._local_compute_checksum(query, *args)
else: except aiohttp.ClientError as e:
replacement = value logging.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
return await self._local_compute_checksum(query, *args)
async def update_query_checksum(self, query_id: int, checksum: str):
conn = await self.get_connection()
try:
await conn.execute("""
UPDATE query_tracking
SET result_checksum = $1
WHERE id = $2
""", checksum, query_id)
finally:
await self.release_connection(self.local_ts_id, conn)
async def _replicate_write(self, ts_id: str, query_id: int, query: str, args: tuple, expected_checksum: str):
try:
conn = await self.get_connection(ts_id)
try:
await conn.execute(query, *args)
actual_checksum = await self.compute_checksum(query, *args)
if actual_checksum != expected_checksum:
raise ValueError(f"Checksum mismatch on {ts_id}")
await self.mark_query_completed(query_id, ts_id)
finally:
await self.release_connection(ts_id, conn)
except Exception as e:
logging.error(f"Failed to replicate write on {ts_id}: {str(e)}")
async def mark_query_completed(self, query_id: int, ts_id: str):
conn = await self.get_connection()
try:
await conn.execute("""
UPDATE query_tracking
SET completed_by = completed_by || jsonb_build_object($1, true)
WHERE id = $2
""", ts_id, query_id)
finally:
await self.release_connection(self.local_ts_id, conn)
async def sync_local_server(self):
conn = await self.get_connection()
try:
last_synced_id = await conn.fetchval("""
SELECT COALESCE(MAX(id), 0) FROM query_tracking
WHERE completed_by ? $1
""", self.local_ts_id)
unexecuted_queries = await conn.fetch("""
SELECT id, query, args, result_checksum
FROM query_tracking
WHERE id > $1
ORDER BY id
""", last_synced_id)
for query in unexecuted_queries:
try:
await conn.execute(query['query'], *json.loads(query['args']))
actual_checksum = await self.compute_checksum(query['query'], *json.loads(query['args']))
if actual_checksum != query['result_checksum']:
raise ValueError(f"Checksum mismatch for query ID {query['id']}")
await self.mark_query_completed(query['id'], self.local_ts_id)
except Exception as e:
logging.error(f"Failed to execute query ID {query['id']} during local sync: {str(e)}")
logging.info(f"Local server sync completed. Executed {len(unexecuted_queries)} queries.")
finally:
await self.release_connection(self.local_ts_id, conn)
async def purge_completed_queries(self):
conn = await self.get_connection()
try:
all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
result = await conn.execute("""
WITH consecutive_completed AS (
SELECT id,
row_number() OVER (ORDER BY id) AS rn
FROM query_tracking
WHERE completed_by ?& $1
)
DELETE FROM query_tracking
WHERE id IN (
SELECT id
FROM consecutive_completed
WHERE rn = (SELECT MAX(rn) FROM consecutive_completed)
)
""", all_ts_ids)
deleted_count = int(result.split()[-1])
logging.info(f"Purged {deleted_count} completed queries.")
finally:
await self.release_connection(self.local_ts_id, conn)
async def close(self):
for pool in self.pool_connections.values():
await pool.close()
value = value.replace('{{' + match + '}}', replacement)
return Path(value).expanduser()
@classmethod # Configuration class for API & Database methods.
def create_dynamic_model(cls, **data):
DynamicModel = create_model(
f'Dynamic{cls.__name__}',
__base__=cls,
**{k: (Path, v) for k, v in data.items()}
)
return DynamicModel(**data)
class Config:
arbitrary_types_allowed = True
# Configuration class for API & Database methods.
class APIConfig(BaseModel): class APIConfig(BaseModel):
HOST: str HOST: str
PORT: int PORT: int

View file

@ -94,7 +94,7 @@ TRUSTED_SUBNETS=127.0.0.1/32,10.13.37.0/24,100.64.64.0/24
# ────────── # ──────────
# #
#─── router selection: ──────────────────────────────────────────────────────────── #─── router selection: ────────────────────────────────────────────────────────────
ROUTERS=asr,cal,cf,email,health,llm,loc,note,rag,img,serve,time,tts,weather ROUTERS=asr,cal,cf,email,llm,loc,note,rag,img,serve,sys,time,tts,weather
UNLOADED=ig UNLOADED=ig
#─── notes: ────────────────────────────────────────────────────────────────────── #─── notes: ──────────────────────────────────────────────────────────────────────
# #

View file

@ -27,7 +27,6 @@ MODULES:
dist: off dist: off
email: on email: on
gis: on gis: on
health: on
ig: off ig: off
img: on img: on
llm: on llm: on
@ -36,6 +35,7 @@ MODULES:
rag: off rag: off
scrape: on scrape: on
serve: on serve: on
sys: on
timing: on timing: on
tts: on tts: on
weather: on weather: on

View file

@ -270,39 +270,35 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli
else: else:
end_date = start_date.replace(hour=23, minute=59, second=59) end_date = start_date.replace(hour=23, minute=59, second=59)
# Fetch locations
locations = await fetch_locations(start_date, end_date) locations = await fetch_locations(start_date, end_date)
if not locations: if not locations:
raise ValueError("No locations found for the given date range") raise ValueError("No locations found for the given date range")
# Create map
m = folium.Map() m = folium.Map()
# Prepare heatmap data
heat_data = [[loc.latitude, loc.longitude] for loc in locations] heat_data = [[loc.latitude, loc.longitude] for loc in locations]
# Add heatmap layer
HeatMap(heat_data).add_to(m) HeatMap(heat_data).add_to(m)
# Fit the map to the bounds of all locations
bounds = [ bounds = [
[min(loc.latitude for loc in locations), min(loc.longitude for loc in locations)], [min(loc.latitude for loc in locations), min(loc.longitude for loc in locations)],
[max(loc.latitude for loc in locations), max(loc.longitude for loc in locations)] [max(loc.latitude for loc in locations), max(loc.longitude for loc in locations)]
] ]
m.fit_bounds(bounds) m.fit_bounds(bounds)
# Generate output path if not provided try:
if output_path is None: if output_path is None:
output_path, relative_path = assemble_journal_path(end_date, filename="map", extension=".png", no_timestamp=True) output_path, relative_path = assemble_journal_path(end_date, filename="map", extension=".png", no_timestamp=True)
# Save the map as PNG m.save(str(output_path))
m.save(str(output_path))
info(f"Heatmap saved as PNG: {output_path}")
info(f"Heatmap saved as PNG: {output_path}") return output_path
return output_path
except Exception as e:
err(f"Error saving heatmap: {str(e)}")
raise
except Exception as e: except Exception as e:
err(f"Error generating and saving heatmap: {str(e)}") err(f"Error generating heatmap: {str(e)}")
raise raise
async def generate_map(start_date: datetime, end_date: datetime, max_points: int): async def generate_map(start_date: datetime, end_date: datetime, max_points: int):

View file

@ -337,8 +337,8 @@ Obsidian helper. Takes a datetime and creates a new daily note. Note: it uses th
_, note_path = assemble_journal_path(date_time, filename="Notes", extension=".md", no_timestamp = True) _, note_path = assemble_journal_path(date_time, filename="Notes", extension=".md", no_timestamp = True)
note_embed = f"![[{note_path}]]" note_embed = f"![[{note_path}]]"
_, map_path = assemble_journal_path(date_time, filename="Map", extension=".png", no_timestamp = True) absolute_map_path, map_path = assemble_journal_path(date_time, filename="Map", extension=".png", no_timestamp = True)
map = await gis.generate_and_save_heatmap(date_time, output_path=map_path) map = await gis.generate_and_save_heatmap(date_time, output_path=absolute_map_path)
map_embed = f"![[{map_path}]]" map_embed = f"![[{map_path}]]"
_, banner_path = assemble_journal_path(date_time, filename="Banner", extension=".jpg", no_timestamp = True) _, banner_path = assemble_journal_path(date_time, filename="Banner", extension=".jpg", no_timestamp = True)

View file

@ -1,9 +1,7 @@
''' '''
Health check module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP. System module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP.
Depends on:
TS_ID, LOGGER, SUBNET_BROADCAST
''' '''
#routers/health.py #routers/sys.py
import os import os
import httpx import httpx
@ -12,7 +10,7 @@ from fastapi import APIRouter
from tailscale import Tailscale from tailscale import Tailscale
from sijapi import L, API, TS_ID, SUBNET_BROADCAST from sijapi import L, API, TS_ID, SUBNET_BROADCAST
health = APIRouter(tags=["public", "trusted", "private"]) sys = APIRouter(tags=["public", "trusted", "private"])
logger = L.get_module_logger("health") logger = L.get_module_logger("health")
def debug(text: str): logger.debug(text) def debug(text: str): logger.debug(text)
def info(text: str): logger.info(text) def info(text: str): logger.info(text)
@ -20,20 +18,20 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
@health.get("/health") @sys.get("/health")
def get_health(): def get_health():
return {"status": "ok"} return {"status": "ok"}
@health.get("/id") @sys.get("/id")
def get_health() -> str: def get_health() -> str:
return TS_ID return TS_ID
@health.get("/routers") @sys.get("/routers")
def get_routers() -> str: def get_routers() -> str:
active_modules = [module for module, is_active in API.MODULES.__dict__.items() if is_active] active_modules = [module for module, is_active in API.MODULES.__dict__.items() if is_active]
return active_modules return active_modules
@health.get("/ip") @sys.get("/ip")
def get_local_ip(): def get_local_ip():
"""Get the server's local IP address.""" """Get the server's local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -46,7 +44,7 @@ def get_local_ip():
s.close() s.close()
return IP return IP
@health.get("/wan_ip") @sys.get("/wan_ip")
async def get_wan_ip(): async def get_wan_ip():
"""Get the WAN IP address using Mullvad's API.""" """Get the WAN IP address using Mullvad's API."""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -59,7 +57,7 @@ async def get_wan_ip():
err(f"Error fetching WAN IP: {e}") err(f"Error fetching WAN IP: {e}")
return "Unavailable" return "Unavailable"
@health.get("/ts_ip") @sys.get("/ts_ip")
async def get_tailscale_ip(): async def get_tailscale_ip():
"""Get the Tailscale IP address.""" """Get the Tailscale IP address."""
tailnet = os.getenv("TAILNET") tailnet = os.getenv("TAILNET")

View file

@ -203,7 +203,7 @@ async def generate_speech(
speed: float = 1.1, speed: float = 1.1,
podcast: bool = False, podcast: bool = False,
title: str = None, title: str = None,
output_dir = None output_dir = None,
) -> str: ) -> str:
debug(f"Entering generate_speech function") debug(f"Entering generate_speech function")
debug(f"API.EXTENSIONS: {API.EXTENSIONS}") debug(f"API.EXTENSIONS: {API.EXTENSIONS}")
@ -213,14 +213,14 @@ async def generate_speech(
debug(f"Type of Tts: {type(Tts)}") debug(f"Type of Tts: {type(Tts)}")
debug(f"Dir of Tts: {dir(Tts)}") debug(f"Dir of Tts: {dir(Tts)}")
output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR
if not output_dir.exists(): use_output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR
output_dir.mkdir(parents=True) if not use_output_dir.exists(): use_output_dir.mkdir(parents=True)
try: try:
model = model if model else await get_model(voice, voice_file) model = model if model else await get_model(voice, voice_file)
title = title if title else "TTS audio" title = title if title else "TTS audio"
output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav" output_path = use_output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
debug(f"Model: {model}") debug(f"Model: {model}")
debug(f"Voice: {voice}") debug(f"Voice: {voice}")
@ -228,7 +228,7 @@ async def generate_speech(
if model == "eleven_turbo_v2" and getattr(API.EXTENSIONS, 'elevenlabs', False): if model == "eleven_turbo_v2" and getattr(API.EXTENSIONS, 'elevenlabs', False):
info("Using ElevenLabs.") info("Using ElevenLabs.")
audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir) audio_file_path = await elevenlabs_tts(model, text, voice, title, use_output_dir)
elif getattr(API.EXTENSIONS, 'xtts', False): elif getattr(API.EXTENSIONS, 'xtts', False):
info("Using XTTS2") info("Using XTTS2")
audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path) audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
@ -244,24 +244,29 @@ async def generate_speech(
warn(f"No file exists at {audio_file_path}") warn(f"No file exists at {audio_file_path}")
if podcast: if podcast:
podcast_path = Dir.PODCAST / audio_file_path podcast_path = Dir.PODCAST / audio_file_path.name
shutil.copy(audio_file_path, podcast_path)
if podcast_path.exists():
info(f"Saved to podcast path: {podcast_path}")
else:
warn(f"Podcast mode enabled, but failed to save to {podcast_path}")
if podcast_path != audio_file_path: if podcast_path != audio_file_path:
info(f"Podcast mode enabled, so we will remove {audio_file_path}") shutil.copy(audio_file_path, podcast_path)
bg_tasks.add_task(os.remove, audio_file_path) if podcast_path.exists():
info(f"Saved to podcast path: {podcast_path}")
else:
warn(f"Podcast mode enabled, but failed to save to {podcast_path}")
if output_dir and Path(output_dir) == use_output_dir:
debug(f"Keeping {audio_file_path} because it was specified")
else:
info(f"Podcast mode enabled and output_dir not specified so we will remove {audio_file_path}")
bg_tasks.add_task(os.remove, audio_file_path)
else: else:
warn(f"Podcast path set to same as audio file path...") warn(f"Podcast path is the same as audio file path. Using existing file.")
return podcast_path return podcast_path
return audio_file_path return audio_file_path
except Exception as e: except Exception as e:
err(f"Failed to generate speech: {e}") err(f"Failed to generate speech: {e}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")