Auto-update: Thu Aug 8 15:30:58 PDT 2024

sanj 2024-08-08 15:30:58 -07:00
parent 3d7ae743e6
commit f685e297fc
2 changed files with 116 additions and 109 deletions

View file

@@ -125,11 +125,11 @@ async def hook_alert(request: Request):
 async def notify(alert: str):
     fail = True
     try:
-        if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
+        if API.EXTENSIONS.shellfish == True:
             await notify_shellfish(alert)
             fail = False
-        if API.EXTENSIONS.macnotify == "on" or API.EXTENSIONS.macnotify == True:
+        if API.EXTENSIONS.macnotify == True:
             if TS_ID == MAC_ID:
                 await notify_local(alert)
                 fail = False
@@ -165,7 +165,7 @@ async def notify_remote(host: str, message: str, username: str = None, password:
     ssh.close()
-if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
+if API.EXTENSIONS.shellfish == True:
     async def notify_shellfish(alert: str):
         key = "d7e810e7601cd296a05776c169b4fe97a6a5ee1fd46abe38de54f415732b3f4b"
         user = "WuqPwm1VpGijF4U5AnIKzqNMVWGioANTRjJoonPm"
@@ -250,7 +250,7 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
         return result.stdout
-if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
+if API.EXTENSIONS.courtlistener == True:
     with open(CASETABLE_PATH, 'r') as file:
         CASETABLE = json.load(file)
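
The same change (the "on"-string comparison collapsed to a plain boolean check) is applied to the shellfish, macnotify, and courtlistener gates above. The assumption behind it is that the extension flags in config/api.yaml arrive as real booleans once the config is loaded. A minimal sketch of that assumption; the loader below is hypothetical, and only the API.EXTENSIONS attribute names come from the diff:

# Hypothetical config loader: coerce legacy "on"/"off" values to booleans at load time,
# so call sites can simply test `API.EXTENSIONS.shellfish == True`.
from types import SimpleNamespace

def load_extensions(raw: dict) -> SimpleNamespace:
    truthy = {"on", "true", "yes", "1", 1, True}
    return SimpleNamespace(**{name: value in truthy for name, value in raw.items()})

EXTENSIONS = load_extensions({"shellfish": "on", "macnotify": True, "courtlistener": False})
assert EXTENSIONS.shellfish is True and EXTENSIONS.courtlistener is False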

View file

@@ -14,7 +14,6 @@ import asyncio
 from pydantic import BaseModel
 from typing import Optional, Union, List
 from pydub import AudioSegment
-from TTS.api import TTS
 from pathlib import Path
 from datetime import datetime as dt_datetime
 from time import time
@@ -144,12 +143,15 @@ async def generate_speech(
     title = title if title else "TTS audio"
     output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
-    if model == "eleven_turbo_v2":
+    if model == "eleven_turbo_v2" and API.EXTENSIONS.elevenlabs == True:
         info("Using ElevenLabs.")
         audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir)
-    else: # if model == "xtts":
+    elif API.EXTENSIONS.xtts == True:
         info("Using XTTS2")
         audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
+    else:
+        err(f"No TTS module enabled!")
+        return None
     if not audio_file_path:
         raise ValueError("TTS generation failed: audio_file_path is empty or None")
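
Because of the new else branch, generate_speech can now return None (instead of assuming XTTS) when neither the elevenlabs nor the xtts extension is enabled, so callers have to treat a None path as "no TTS backend configured". A small, self-contained sketch of that caller-side guard; the wrapper name and the 501 status are assumptions, only the None return comes from the diff:

from fastapi import HTTPException

async def speak_or_501(generate_speech, **kwargs) -> str:
    # Wrap a generate_speech-style coroutine and turn its None result into an HTTP error.
    audio_file_path = await generate_speech(**kwargs)
    if audio_file_path is None:
        # generate_speech returns None when neither elevenlabs nor xtts is enabled.
        raise HTTPException(status_code=501, detail="No TTS module enabled in config/api.yaml")
    return audio_file_path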
@@ -183,39 +185,25 @@ async def generate_speech(
 async def get_model(voice: str = None, voice_file: UploadFile = None):
-    if voice_file or (voice and await select_voice(voice)):
+    if (voice_file or (voice and await select_voice(voice))) and API.EXTENSIONS.xtts == True:
         return "xtts"
-    elif voice and await determine_voice_id(voice):
+    elif voice and await determine_voice_id(voice) and API.EXTENSIONS.elevenlabs == True:
         return "eleven_turbo_v2"
     else:
-        raise HTTPException(status_code=400, detail="No model or voice specified")
+        err(f"No model or voice specified, or no TTS module loaded")
+        raise HTTPException(status_code=400, detail="No model or voice specified, or no TTS module loaded")
 async def determine_voice_id(voice_name: str) -> str:
     debug(f"Searching for voice id for {voice_name}")
-    # Todo: move this to tts.yaml
-    hardcoded_voices = {
-        "alloy": "E3A1KVbKoWSIKSZwSUsW",
-        "echo": "b42GBisbu9r5m5n6pHF7",
-        "fable": "KAX2Y6tTs0oDWq7zZXW7",
-        "onyx": "clQb8NxY08xZ6mX6wCPE",
-        "nova": "6TayTBKLMOsghG7jYuMX",
-        "shimmer": "E7soeOyjpmuZFurvoxZ2",
-        "Luna": "6TayTBKLMOsghG7jYuMX",
-        "Sangye": "E7soeOyjpmuZFurvoxZ2",
-        "Herzog": "KAX2Y6tTs0oDWq7zZXW7",
-        "Attenborough": "b42GBisbu9r5m5n6pHF7",
-        "Victoria": "7UBkHqZOtFRLq6cSMQQg"
-    }
     if voice_name in Tts.elevenlabs.voices:
         voice_id = hardcoded_voices[voice_name]
         debug(f"Found voice ID - {voice_id}")
         return voice_id
-    debug(f"Requested voice not among the hardcoded options.. checking with 11L next.")
+    debug(f"Requested voice not among the voices specified in config/tts.yaml. Checking with 11L next.")
     url = "https://api.elevenlabs.io/v1/voices"
     headers = {"xi-api-key": ELEVENLABS_API_KEY}
     async with httpx.AsyncClient() as client:
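
The hardcoded name-to-ID table (with its "move this to tts.yaml" TODO) is gone; lookups now go through Tts.elevenlabs.voices, i.e. a voices mapping kept in config/tts.yaml. A minimal sketch of how that config might be shaped and loaded; the yaml layout, the loader, and the choice of default are assumptions (the voice names and IDs are taken from the dict removed above, and only Tts.elevenlabs.voices / Tts.elevenlabs.default appear in the diff):

import yaml
from types import SimpleNamespace

# Assumed config/tts.yaml layout (truncated to two of the voices from the removed dict):
TTS_YAML = """
elevenlabs:
  default: Attenborough
  voices:
    alloy: E3A1KVbKoWSIKSZwSUsW
    Attenborough: b42GBisbu9r5m5n6pHF7
"""

raw = yaml.safe_load(TTS_YAML)
Tts = SimpleNamespace(elevenlabs=SimpleNamespace(**raw["elevenlabs"]))

if "Attenborough" in Tts.elevenlabs.voices:        # mirrors the membership check in the diff
    print(Tts.elevenlabs.voices["Attenborough"])   # -> b42GBisbu9r5m5n6pHF7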
@@ -231,12 +219,13 @@ async def determine_voice_id(voice_name: str) -> str:
     except Exception as e:
         err(f"Error determining voice ID: {str(e)}")
-    # as a last fallback, rely on David Attenborough; move this to tts.yaml
-    return "b42GBisbu9r5m5n6pHF7"
+    warn(f"Voice '{voice_name}' not found; attempting to use the default specified in config/tts.yaml: {Tts.elevenlabs.default}")
+    return await determine_voice_id(Tts.elevenlabs.default)
 async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = None, output_dir: str = None):
+    if API.EXTENSIONS.elevenlabs == True:
         voice_id = await determine_voice_id(voice)
         url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
@@ -264,6 +253,9 @@ async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = N
             err(f"Error from Elevenlabs API: {e}")
             raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
+    else:
+        warn(f"elevenlabs_tts called but ElevenLabs module disabled in config/api.yaml!")
@@ -324,6 +316,10 @@ async def local_tts(
     title: str = None,
     output_path: Optional[Path] = None
 ) -> str:
+    if API.EXTENSIONS.xtts == True:
+        from TTS.api import TTS
         if output_path:
             file_path = Path(output_path)
         else:
@@ -375,6 +371,10 @@ async def local_tts(
         return str(file_path)
+    else:
+        warn(f"local_tts called but xtts module disabled!")
+        return None
 async def stream_tts(text_content: str, speed: float, voice: str, voice_file) -> StreamingResponse:
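
Note that from TTS.api import TTS moves from the top of the module (removed in the first hunk of this file) into the xtts-gated branches of local_tts and generate_tts, so the heavyweight Coqui TTS dependency is only imported when the extension is actually enabled. A minimal, self-contained sketch of that deferred-import pattern; the function name and the example model id are placeholders, only the TTS(model_name=...) call shape comes from the diff:

def maybe_load_xtts(enabled: bool):
    # Illustrative deferred import: the TTS package is only touched when the xtts feature is on,
    # so a missing or broken install breaks this code path rather than module import.
    if not enabled:
        return None
    from TTS.api import TTS  # deferred, heavyweight import
    return TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2")  # example model id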
@@ -394,6 +394,10 @@ async def stream_tts(text_content: str, speed: float, voice: str, voice_file) ->
 async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
+    if API.EXTENSIONS.xtts == True:
+        from TTS.api import TTS
         output_dir = tempfile.mktemp(suffix=".wav", dir=tempfile.gettempdir())
         XTTS = TTS(model_name=Tts.xtts.model).to(DEVICE)
@@ -401,6 +405,10 @@ async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
         return output_dir
+    else:
+        warn(f"generate_tts called but xtts module disabled!")
+        return None
 async def get_audio_stream(model: str, input_text: str, voice: str):
     voice_id = await determine_voice_id(voice)
@@ -455,7 +463,6 @@ def clean_text_for_tts(text: str) -> str:
         debug(f"No text received.")
 def copy_to_podcast_dir(file_path):
     try:
         # Extract the file name from the file path