From a825ed23b2f37a2ba1aa3cca0fc727a08ee0b31a Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Thu, 8 Aug 2024 15:30:58 -0700
Subject: [PATCH] Auto-update: Thu Aug  8 15:30:58 PDT 2024

---
 sijapi/routers/serve.py |   8 +-
 sijapi/routers/tts.py   | 225 ++++++++++++++++++++++-------------------
 2 files changed, 123 insertions(+), 110 deletions(-)

diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py
index abbb896..8404a34 100644
--- a/sijapi/routers/serve.py
+++ b/sijapi/routers/serve.py
@@ -125,11 +125,11 @@ async def hook_alert(request: Request):
 async def notify(alert: str):
     fail = True
     try:
-        if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
+        if API.EXTENSIONS.shellfish == True:
             await notify_shellfish(alert)
             fail = False
 
-        if API.EXTENSIONS.macnotify == "on" or API.EXTENSIONS.macnotify == True:
+        if API.EXTENSIONS.macnotify == True:
             if TS_ID == MAC_ID:
                 await notify_local(alert)
                 fail = False
@@ -165,7 +165,7 @@ async def notify_remote(host: str, message: str, username: str = None, password:
     ssh.close()
 
 
-if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
+if API.EXTENSIONS.shellfish == True:
     async def notify_shellfish(alert: str):
         key = "d7e810e7601cd296a05776c169b4fe97a6a5ee1fd46abe38de54f415732b3f4b"
         user = "WuqPwm1VpGijF4U5AnIKzqNMVWGioANTRjJoonPm"
@@ -250,7 +250,7 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
         return result.stdout
 
 
-if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
+if API.EXTENSIONS.courtlistener == True:
     with open(CASETABLE_PATH, 'r') as file:
         CASETABLE = json.load(file)
 
diff --git a/sijapi/routers/tts.py b/sijapi/routers/tts.py
index 83ca8ef..1f25fd8 100644
--- a/sijapi/routers/tts.py
+++ b/sijapi/routers/tts.py
@@ -14,7 +14,6 @@ import asyncio
 from pydantic import BaseModel
 from typing import Optional, Union, List
 from pydub import AudioSegment
-from TTS.api import TTS
 from pathlib import Path
 from datetime import datetime as dt_datetime
 from time import time
@@ -144,12 +143,16 @@ async def generate_speech(
         title = title if title else "TTS audio"
         output_path = output_dir / f"{dt_datetime.now().strftime('%Y%m%d%H%M%S')} {title}.wav"
         
-        if model == "eleven_turbo_v2":
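+        # Pick a backend: use ElevenLabs only when that extension is enabled, otherwise fall back to the local XTTS backend if it is enabled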
+        if model == "eleven_turbo_v2" and API.EXTENSIONS.elevenlabs == True:
             info("Using ElevenLabs.")
             audio_file_path = await elevenlabs_tts(model, text, voice, title, output_dir)
-        else:  # if model == "xtts":
+        elif API.EXTENSIONS.xtts == True:
             info("Using XTTS2")
             audio_file_path = await local_tts(text, speed, voice, voice_file, podcast, bg_tasks, title, output_path)
+        else:
+            err("No TTS module enabled!")
+            return None
 
         if not audio_file_path:
             raise ValueError("TTS generation failed: audio_file_path is empty or None")
@@ -183,39 +185,25 @@ async def generate_speech(
 
 
 async def get_model(voice: str = None, voice_file: UploadFile = None):
-    if voice_file or (voice and await select_voice(voice)):
+    if API.EXTENSIONS.xtts == True and (voice_file or (voice and await select_voice(voice))):
         return "xtts"
     
-    elif voice and await determine_voice_id(voice):
+    elif API.EXTENSIONS.elevenlabs == True and voice and await determine_voice_id(voice):
         return "eleven_turbo_v2"
     
     else:
-        raise HTTPException(status_code=400, detail="No model or voice specified")
+        err(f"No model or voice specified, or no TTS module loaded")
+        raise HTTPException(status_code=400, detail="No model or voice specified, or no TTS module loaded")
 
 async def determine_voice_id(voice_name: str) -> str:
     debug(f"Searching for voice id for {voice_name}")
-    
-    # Todo: move this to tts.yaml
-    hardcoded_voices = {
-        "alloy": "E3A1KVbKoWSIKSZwSUsW",
-        "echo": "b42GBisbu9r5m5n6pHF7",
-        "fable": "KAX2Y6tTs0oDWq7zZXW7",
-        "onyx": "clQb8NxY08xZ6mX6wCPE",
-        "nova": "6TayTBKLMOsghG7jYuMX",
-        "shimmer": "E7soeOyjpmuZFurvoxZ2",
-        "Luna": "6TayTBKLMOsghG7jYuMX",
-        "Sangye": "E7soeOyjpmuZFurvoxZ2",
-        "Herzog": "KAX2Y6tTs0oDWq7zZXW7",
-        "Attenborough": "b42GBisbu9r5m5n6pHF7",
-        "Victoria": "7UBkHqZOtFRLq6cSMQQg"
-    }
 
     if voice_name in Tts.elevenlabs.voices:
-        voice_id = hardcoded_voices[voice_name]
+        voice_id = Tts.elevenlabs.voices[voice_name]
         debug(f"Found voice ID - {voice_id}")
         return voice_id
 
-    debug(f"Requested voice not among the hardcoded options.. checking with 11L next.")
+    debug(f"Requested voice not among the voices specified in config/tts.yaml. Checking with 11L next.")
     url = "https://api.elevenlabs.io/v1/voices"
     headers = {"xi-api-key": ELEVENLABS_API_KEY}
     async with httpx.AsyncClient() as client:
@@ -231,38 +219,45 @@ async def determine_voice_id(voice_name: str) -> str:
         except Exception as e:
             err(f"Error determining voice ID: {str(e)}")
 
-    # as a last fallback, rely on David Attenborough; move this to tts.yaml
-    return "b42GBisbu9r5m5n6pHF7"
+    warn(f"Voice \'{voice_name}\' not found; attempting to use the default specified in config/tts.yaml: {Tts.elevenlabs.default}")
+    return await determine_voice_id(Tts.elevenlabs.default)
 
 
 async def elevenlabs_tts(model: str, input_text: str, voice: str, title: str = None, output_dir: str = None):
 
-    voice_id = await determine_voice_id(voice)
-
-    url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
-    payload = {
-        "text": input_text,
-        "model_id": model
-    }
-    headers = {"Content-Type": "application/json", "xi-api-key": ELEVENLABS_API_KEY}
-    try:
-        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:  # 5 minutes timeout
-            response = await client.post(url, json=payload, headers=headers)
-            output_dir = output_dir if output_dir else TTS_OUTPUT_DIR
-            title = title if title else dt_datetime.now().strftime("%Y%m%d%H%M%S")
-            filename = f"{sanitize_filename(title)}.mp3"
-            file_path = Path(output_dir) / filename
-            if response.status_code == 200:            
-                with open(file_path, "wb") as audio_file:
-                    audio_file.write(response.content)
-                # info(f"file_path: {file_path}")
-                return file_path
-            else:
-                raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
-            
-    except Exception as e:
-        err(f"Error from Elevenlabs API: {e}")
-        raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
+    if API.EXTENSIONS.elevenlabs == True:
+        voice_id = await determine_voice_id(voice)
+    
+        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
+        payload = {
+            "text": input_text,
+            "model_id": model
+        }
+        headers = {"Content-Type": "application/json", "xi-api-key": ELEVENLABS_API_KEY}
+        try:
+            async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:  # 5 minutes timeout
+                response = await client.post(url, json=payload, headers=headers)
+                output_dir = output_dir if output_dir else TTS_OUTPUT_DIR
+                title = title if title else dt_datetime.now().strftime("%Y%m%d%H%M%S")
+                filename = f"{sanitize_filename(title)}.mp3"
+                file_path = Path(output_dir) / filename
+                if response.status_code == 200:            
+                    with open(file_path, "wb") as audio_file:
+                        audio_file.write(response.content)
+                    # info(f"file_path: {file_path}")
+                    return file_path
+                else:
+                    raise HTTPException(status_code=response.status_code, detail="Error from ElevenLabs API")
+                
+        except HTTPException:
+            raise
+        except Exception as e:
+            err(f"Error from ElevenLabs API: {e}")
+            raise HTTPException(status_code=500, detail="Error from ElevenLabs API")
+    
+    else:
+        warn(f"elevenlabs_tts called but ElevenLabs module disabled in config/api.yaml!")
 
 
 
@@ -324,56 +316,65 @@ async def local_tts(
     title: str = None,
     output_path: Optional[Path] = None
 ) -> str:
-    if output_path:
-        file_path = Path(output_path)
+
+    if API.EXTENSIONS.xtts == True:
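+        # Imported here rather than at module level so the heavy TTS dependency is only loaded when the xtts extension is enabled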
+        from TTS.api import TTS
+        
+        if output_path:
+            file_path = Path(output_path)
+        else:
+            datetime_str = dt_datetime.now().strftime("%Y%m%d%H%M%S")
+            title = sanitize_filename(title) if title else "Audio"
+            filename = f"{datetime_str}_{title}.wav"
+            file_path = TTS_OUTPUT_DIR / filename
+    
+        # Ensure the parent directory exists
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+    
+        voice_file_path = await get_voice_file_path(voice, voice_file)
+        
+        # Initialize TTS model in a separate thread
+        XTTS = await asyncio.to_thread(TTS, model_name=Tts.xtts.model)
+        await asyncio.to_thread(XTTS.to, DEVICE)
+    
+        segments = split_text(text_content)
+        combined_audio = AudioSegment.silent(duration=0)
+    
+        for i, segment in enumerate(segments):
+            segment_file_path = TTS_SEGMENTS_DIR / f"segment_{i}.wav"
+            debug(f"Segment file path: {segment_file_path}")
+            
+            # Run TTS in a separate thread
+            await asyncio.to_thread(
+                XTTS.tts_to_file,
+                text=segment,
+                speed=speed,
+                file_path=str(segment_file_path),
+                speaker_wav=[voice_file_path],
+                language="en"
+            )
+            debug(f"Segment file generated: {segment_file_path}")
+            
+            # Load and combine audio in a separate thread
+            segment_audio = await asyncio.to_thread(AudioSegment.from_wav, str(segment_file_path))
+            combined_audio += segment_audio
+    
+            # Delete the segment file
+            await asyncio.to_thread(segment_file_path.unlink)
+    
+        # Export the combined audio in a separate thread
+        if podcast:
+            podcast_file_path = Path(Dir.PODCAST) / file_path.name
+            await asyncio.to_thread(combined_audio.export, podcast_file_path, format="wav")
+        
+        await asyncio.to_thread(combined_audio.export, file_path, format="wav")
+    
+        return str(file_path)
+        
     else:
-        datetime_str = dt_datetime.now().strftime("%Y%m%d%H%M%S")
-        title = sanitize_filename(title) if title else "Audio"
-        filename = f"{datetime_str}_{title}.wav"
-        file_path = TTS_OUTPUT_DIR / filename
-
-    # Ensure the parent directory exists
-    file_path.parent.mkdir(parents=True, exist_ok=True)
-
-    voice_file_path = await get_voice_file_path(voice, voice_file)
-    
-    # Initialize TTS model in a separate thread
-    XTTS = await asyncio.to_thread(TTS, model_name=Tts.xtts.model)
-    await asyncio.to_thread(XTTS.to, DEVICE)
-
-    segments = split_text(text_content)
-    combined_audio = AudioSegment.silent(duration=0)
-
-    for i, segment in enumerate(segments):
-        segment_file_path = TTS_SEGMENTS_DIR / f"segment_{i}.wav"
-        debug(f"Segment file path: {segment_file_path}")
-        
-        # Run TTS in a separate thread
-        await asyncio.to_thread(
-            XTTS.tts_to_file,
-            text=segment,
-            speed=speed,
-            file_path=str(segment_file_path),
-            speaker_wav=[voice_file_path],
-            language="en"
-        )
-        debug(f"Segment file generated: {segment_file_path}")
-        
-        # Load and combine audio in a separate thread
-        segment_audio = await asyncio.to_thread(AudioSegment.from_wav, str(segment_file_path))
-        combined_audio += segment_audio
-
-        # Delete the segment file
-        await asyncio.to_thread(segment_file_path.unlink)
-
-    # Export the combined audio in a separate thread
-    if podcast:
-        podcast_file_path = Path(Dir.PODCAST) / file_path.name
-        await asyncio.to_thread(combined_audio.export, podcast_file_path, format="wav")
-    
-    await asyncio.to_thread(combined_audio.export, file_path, format="wav")
-
-    return str(file_path)
+        warn(f"local_tts called but xtts module disabled!")
+        return None
 
 
 
@@ -394,12 +394,21 @@ async def stream_tts(text_content: str, speed: float, voice: str, voice_file) ->
 
 
 async def generate_tts(text: str, speed: float, voice_file_path: str) -> str:
-    output_dir = tempfile.mktemp(suffix=".wav", dir=tempfile.gettempdir())
+    
+    if API.EXTENSIONS.xtts == True:
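+        # Lazy import, as in local_tts, so the TTS package is only required when the xtts extension is enabled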
+        from TTS.api import TTS
+        
+        output_dir = tempfile.mktemp(suffix=".wav", dir=tempfile.gettempdir())
+    
+        XTTS = TTS(model_name=Tts.xtts.model).to(DEVICE)
+        XTTS.tts_to_file(text=text, speed=speed, file_path=output_dir, speaker_wav=[voice_file_path], language="en")
 
-    XTTS = TTS(model_name=Tts.xtts.model).to(DEVICE)
-    XTTS.tts_to_file(text=text, speed=speed, file_path=output_dir, speaker_wav=[voice_file_path], language="en")
-
-    return output_dir
+        return output_dir
+        
+    else:
+        warn(f"generate_tts called but xtts module disabled!")
+        return None
 
 
 async def get_audio_stream(model: str, input_text: str, voice: str):
@@ -455,7 +463,6 @@ def clean_text_for_tts(text: str) -> str:
         debug(f"No text received.")
 
 
-
 def copy_to_podcast_dir(file_path):
     try:
         # Extract the file name from the file path