Auto-update: Thu Aug 8 15:30:58 PDT 2024
This commit is contained in:
parent
3d7ae743e6
commit
f685e297fc
2 changed files with 116 additions and 109 deletions
|
@ -125,11 +125,11 @@ async def hook_alert(request: Request):
|
||||||
async def notify(alert: str):
|
async def notify(alert: str):
|
||||||
fail = True
|
fail = True
|
||||||
try:
|
try:
|
||||||
if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
|
if API.EXTENSIONS.shellfish == True:
|
||||||
await notify_shellfish(alert)
|
await notify_shellfish(alert)
|
||||||
fail = False
|
fail = False
|
||||||
|
|
||||||
if API.EXTENSIONS.macnotify == "on" or API.EXTENSIONS.macnotify == True:
|
if API.EXTENSIONS.macnotify == True:
|
||||||
if TS_ID == MAC_ID:
|
if TS_ID == MAC_ID:
|
||||||
await notify_local(alert)
|
await notify_local(alert)
|
||||||
fail = False
|
fail = False
|
||||||
|
@ -165,7 +165,7 @@ async def notify_remote(host: str, message: str, username: str = None, password:
|
||||||
ssh.close()
|
ssh.close()
|
||||||
|
|
||||||
|
|
||||||
if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
|
if API.EXTENSIONS.shellfish == True:
|
||||||
async def notify_shellfish(alert: str):
|
async def notify_shellfish(alert: str):
|
||||||
key = "d7e810e7601cd296a05776c169b4fe97a6a5ee1fd46abe38de54f415732b3f4b"
|
key = "d7e810e7601cd296a05776c169b4fe97a6a5ee1fd46abe38de54f415732b3f4b"
|
||||||
user = "WuqPwm1VpGijF4U5AnIKzqNMVWGioANTRjJoonPm"
|
user = "WuqPwm1VpGijF4U5AnIKzqNMVWGioANTRjJoonPm"
|
||||||
|
@ -250,7 +250,7 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
|
||||||
return result.stdout
|
return result.stdout
|
||||||
|
|
||||||
|
|
||||||
if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
|
if API.EXTENSIONS.courtlistener == True:
|
||||||
with open(CASETABLE_PATH, 'r') as file:
|
with open(CASETABLE_PATH, 'r') as file:
|
||||||
CASETABLE = json.load(file)
|
CASETABLE = json.load(file)
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ import asyncio
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from TTS.api import TTS
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime as dt_datetime
|
from datetime import datetime as dt_datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
@ -144,12 +143,15 @@ async def generate_speech(
|
||||||
title = title if title else "TTS audio"
|
title = title if title else "TTS audio"
|
||||||
output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
|
output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
|
||||||
|
|
||||||
if model == "eleven_turbo_v2":
|
if model == "eleven_turbo_v2" and API.EXTENSIONS.elevenlabs == True:
|
||||||
info("Using ElevenLabs.")
|
info("Using ElevenLabs.")
|
||||||
audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir)
|
audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir)
|
||||||
else: # if model == "xtts":
|
elif API.EXTENSIONS.xtts == True:
|
||||||
info("Using XTTS2")
|
info("Using XTTS2")
|
||||||
audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
|
audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
|
||||||
|
else:
|
||||||
|
err(f"No TTS module enabled!")
|
||||||
|
return None
|
||||||
|
|
||||||
if not audio_file_path:
|
if not audio_file_path:
|
||||||
raise ValueError("TTS generation failed: audio_file_path is empty or None")
|
raise ValueError("TTS generation failed: audio_file_path is empty or None")
|
||||||
|
@ -183,39 +185,25 @@ async def generate_speech(
|
||||||
|
|
||||||
|
|
||||||
async def get_model(voice: str = None, voice_file: UploadFile = None):
|
async def get_model(voice: str = None, voice_file: UploadFile = None):
|
||||||
if voice_file or (voice and await select_voice(voice)):
|
if (voice_file or (voice and await select_voice(voice))) and API.EXTENSIONS.xtts == True:
|
||||||
return "xtts"
|
return "xtts"
|
||||||
|
|
||||||
elif voice and await determine_voice_id(voice):
|
elif voice and await determine_voice_id(voice) and API.EXTENSIONS.elevenlabs == True:
|
||||||
return "eleven_turbo_v2"
|
return "eleven_turbo_v2"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail="No model or voice specified")
|
err(f"No model or voice specified, or no TTS module loaded")
|
||||||
|
raise HTTPException(status_code=400, detail="No model or voice specified, or no TTS module loaded")
|
||||||
|
|
||||||
async def determine_voice_id(voice_name: str) -> str:
|
async def determine_voice_id(voice_name: str) -> str:
|
||||||
debug(f"Searching for voice id for {voice_name}")
|
debug(f"Searching for voice id for {voice_name}")
|
||||||
|
|
||||||
# Todo: move this to tts.yaml
|
|
||||||
hardcoded_voices = {
|
|
||||||
"alloy": "E3A1KVbKoWSIKSZwSUsW",
|
|
||||||
"echo": "b42GBisbu9r5m5n6pHF7",
|
|
||||||
"fable": "KAX2Y6tTs0oDWq7zZXW7",
|
|
||||||
"onyx": "clQb8NxY08xZ6mX6wCPE",
|
|
||||||
"nova": "6TayTBKLMOsghG7jYuMX",
|
|
||||||
"shimmer": "E7soeOyjpmuZFurvoxZ2",
|
|
||||||
"Luna": "6TayTBKLMOsghG7jYuMX",
|
|
||||||
"Sangye": "E7soeOyjpmuZFurvoxZ2",
|
|
||||||
"Herzog": "KAX2Y6tTs0oDWq7zZXW7",
|
|
||||||
"Attenborough": "b42GBisbu9r5m5n6pHF7",
|
|
||||||
"Victoria": "7UBkHqZOtFRLq6cSMQQg"
|
|
||||||
}
|
|
||||||
|
|
||||||
if voice_name in Tts.elevenlabs.voices:
|
if voice_name in Tts.elevenlabs.voices:
|
||||||
voice_id = hardcoded_voices[voice_name]
|
voice_id = hardcoded_voices[voice_name]
|
||||||
debug(f"Found voice ID - {voice_id}")
|
debug(f"Found voice ID - {voice_id}")
|
||||||
return voice_id
|
return voice_id
|
||||||
|
|
||||||
debug(f"Requested voice not among the hardcoded options.. checking with 11L next.")
|
debug(f"Requested voice not among the voices specified in config/tts.yaml. Checking with 11L next.")
|
||||||
url = "https://api.elevenlabs.io/v1/voices"
|
url = "https://api.elevenlabs.io/v1/voices"
|
||||||
headers = {"xi-api-key": ELEVENLABS_API_KEY}
|
headers = {"xi-api-key": ELEVENLABS_API_KEY}
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
@ -231,12 +219,13 @@ async def determine_voice_id(voice_name: str) -> str:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error determining voice ID: {str(e)}")
|
err(f"Error determining voice ID: {str(e)}")
|
||||||
|
|
||||||
# as a last fallback, rely on David Attenborough; move this to tts.yaml
|
warn(f"Voice \'{voice_name}\' not found; attempting to use the default specified in config/tts.yaml: {Tts.elevenlabs.default}")
|
||||||
return "b42GBisbu9r5m5n6pHF7"
|
return await determine_voice_id(Tts.elevenlabs.default)
|
||||||
|
|
||||||
|
|
||||||
async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = None, output_dir: str = None):
|
async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = None, output_dir: str = None):
|
||||||
|
|
||||||
|
if API.EXTENSIONS.elevenlabs == True:
|
||||||
voice_id = await determine_voice_id(voice)
|
voice_id = await determine_voice_id(voice)
|
||||||
|
|
||||||
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
|
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
|
||||||
|
@ -264,6 +253,9 @@ async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = N
|
||||||
err(f"Error from Elevenlabs API: {e}")
|
err(f"Error from Elevenlabs API: {e}")
|
||||||
raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
|
raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
|
||||||
|
|
||||||
|
else:
|
||||||
|
warn(f"elevenlabs_tts called but ElevenLabs module disabled in config/api.yaml!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,6 +316,10 @@ async def local_tts(
|
||||||
title: str = None,
|
title: str = None,
|
||||||
output_path: Optional[Path] = None
|
output_path: Optional[Path] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
|
if API.EXTENSIONS.xtts == True:
|
||||||
|
from TTS.api import TTS
|
||||||
|
|
||||||
if output_path:
|
if output_path:
|
||||||
file_path = Path(output_path)
|
file_path = Path(output_path)
|
||||||
else:
|
else:
|
||||||
|
@ -375,6 +371,10 @@ async def local_tts(
|
||||||
|
|
||||||
return str(file_path)
|
return str(file_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
warn(f"local_tts called but xtts module disabled!")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_tts(text_content: str, speed: float, voice: str, voice_file) -> StreamingResponse:
|
async def stream_tts(text_content: str, speed: float, voice: str, voice_file) -> StreamingResponse:
|
||||||
|
@ -394,6 +394,10 @@ async def stream_tts(text_content: str, speed: float, voice: str, voice_file) ->
|
||||||
|
|
||||||
|
|
||||||
async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
|
async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
|
||||||
|
|
||||||
|
if API.EXTENSIONS.xtts == True:
|
||||||
|
from TTS.api import TTS
|
||||||
|
|
||||||
output_dir = tempfile.mktemp(suffix=".wav", dir=tempfile.gettempdir())
|
output_dir = tempfile.mktemp(suffix=".wav", dir=tempfile.gettempdir())
|
||||||
|
|
||||||
XTTS = TTS(model_name=Tts.xtts.model).to(DEVICE)
|
XTTS = TTS(model_name=Tts.xtts.model).to(DEVICE)
|
||||||
|
@ -401,6 +405,10 @@ async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
|
||||||
|
|
||||||
return output_dir
|
return output_dir
|
||||||
|
|
||||||
|
else:
|
||||||
|
warn(f"generate_tts called but xtts module disabled!")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_audio_stream(model: str, input_text: str, voice: str):
|
async def get_audio_stream(model: str, input_text: str, voice: str):
|
||||||
voice_id = await determine_voice_id(voice)
|
voice_id = await determine_voice_id(voice)
|
||||||
|
@ -455,7 +463,6 @@ def clean_text_for_tts(text: str) -> str:
|
||||||
debug(f"No text received.")
|
debug(f"No text received.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def copy_to_podcast_dir(file_path):
|
def copy_to_podcast_dir(file_path):
|
||||||
try:
|
try:
|
||||||
# Extract the file name from the file path
|
# Extract the file name from the file path
|
||||||
|
|
Loading…
Reference in a new issue