Auto-update: Thu Aug 8 16:54:57 PDT 2024

This commit is contained in:
sanj 2024-08-08 16:54:57 -07:00
parent abd1448452
commit df30d1a1ac

View file

@@ -123,6 +123,7 @@ async def generate_speech_endpoint(
err(traceback.format_exc()) err(traceback.format_exc())
raise HTTPException(status_code=666, detail="error in TTS") raise HTTPException(status_code=666, detail="error in TTS")
async def generate_speech( async def generate_speech(
bg_tasks: BackgroundTasks, bg_tasks: BackgroundTasks,
text: str, text: str,
@@ -134,6 +135,14 @@ async def generate_speech(
title: str = None, title: str = None,
output_dir = None output_dir = None
) -> str: ) -> str:
L.debug(f"Entering generate_speech function")
L.debug(f"API.EXTENSIONS: {API.EXTENSIONS}")
L.debug(f"Type of API.EXTENSIONS: {type(API.EXTENSIONS)}")
L.debug(f"Dir of API.EXTENSIONS: {dir(API.EXTENSIONS)}")
L.debug(f"Tts config: {Tts}")
L.debug(f"Type of Tts: {type(Tts)}")
L.debug(f"Dir of Tts: {dir(Tts)}")
output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR output_dir = Path(output_dir) if output_dir else TTS_OUTPUT_DIR
if not output_dir.exists(): if not output_dir.exists():
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
@@ -143,47 +152,53 @@ async def generate_speech(
title = title if title else "TTS audio" title = title if title else "TTS audio"
output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav" output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
if model == "eleven_turbo_v2" and API.EXTENSIONS.elevenlabs: L.debug(f"Model: {model}")
info("Using ElevenLabs.") L.debug(f"API.EXTENSIONS.elevenlabs: {getattr(API.EXTENSIONS, 'elevenlabs', None)}")
L.debug(f"API.EXTENSIONS.xtts: {getattr(API.EXTENSIONS, 'xtts', None)}")
if model == "eleven_turbo_v2" and getattr(API.EXTENSIONS, 'elevenlabs', False):
L.info("Using ElevenLabs.")
audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir) audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir)
elif API.EXTENSIONS.xtts: elif getattr(API.EXTENSIONS, 'xtts', False):
info("Using XTTS2") L.info("Using XTTS2")
audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path) audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
else: else:
err(f"No TTS module enabled!") L.error(f"No TTS module enabled!")
return None raise ValueError("No TTS module enabled")
if not audio_file_path: if not audio_file_path:
raise ValueError("TTS generation failed: audio_file_path is empty or None") raise ValueError("TTS generation failed: audio_file_path is empty or None")
elif audio_file_path.exists(): elif audio_file_path.exists():
info(f"Saved to {audio_file_path}") L.info(f"Saved to {audio_file_path}")
else: else:
warn(f"No file exists at {audio_file_path}") L.warn(f"No file exists at {audio_file_path}")
if podcast: if podcast:
podcast_path = Path(Dir.PODCAST) / Path(audio_file_path).name podcast_path = Path(Dir.PODCAST) / Path(audio_file_path).name
shutil.copy(str(audio_file_path), str(podcast_path)) shutil.copy(str(audio_file_path), str(podcast_path))
if podcast_path.exists(): if podcast_path.exists():
info(f"Saved to podcast path: {podcast_path}") L.info(f"Saved to podcast path: {podcast_path}")
else: else:
warn(f"Podcast mode enabled, but failed to save to {podcast_path}") L.warn(f"Podcast mode enabled, but failed to save to {podcast_path}")
if podcast_path != audio_file_path: if podcast_path != audio_file_path:
info(f"Podcast mode enabled, so we will remove {audio_file_path}") L.info(f"Podcast mode enabled, so we will remove {audio_file_path}")
bg_tasks.add_task(os.remove, str(audio_file_path)) bg_tasks.add_task(os.remove, str(audio_file_path))
else: else:
warn(f"Podcast path set to same as audio file path...") L.warn(f"Podcast path set to same as audio file path...")
return str(podcast_path) return str(podcast_path)
return str(audio_file_path) return str(audio_file_path)
except Exception as e: except Exception as e:
err(f"Failed to generate speech: {str(e)}") L.error(f"Failed to generate speech: {str(e)}")
L.error(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}")
async def get_model(voice: str = None, voice_file: UploadFile = None): async def get_model(voice: str = None, voice_file: UploadFile = None):
if (voice_file or (voice and await select_voice(voice))) and API.EXTENSIONS.xtts: if (voice_file or (voice and await select_voice(voice))) and API.EXTENSIONS.xtts:
return "xtts" return "xtts"