From 7290bf4295fdc801c5d9440260be4a45b50fb94a Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Thu, 27 Jun 2024 19:43:27 -0700 Subject: [PATCH] Fixed asr and finally made it fully async --- sijapi/routers/asr.py | 49 ++++++++++++++++++++++++++++++------------ sijapi/routers/note.py | 8 ++----- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/sijapi/routers/asr.py b/sijapi/routers/asr.py index f518164..adf2caf 100644 --- a/sijapi/routers/asr.py +++ b/sijapi/routers/asr.py @@ -41,7 +41,6 @@ transcription_results = {} @asr.post("/transcribe") @asr.post("/v1/audio/transcription") async def transcribe_endpoint( - bg_tasks: BackgroundTasks, file: UploadFile = File(...), params: str = Form(...) ): @@ -58,7 +57,7 @@ async def transcribe_endpoint( temp_file.write(await file.read()) temp_file_path = temp_file.name - transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters, bg_tasks=bg_tasks) + transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters) job_id = transcription_job["job_id"] # Poll for completion @@ -80,7 +79,7 @@ async def transcribe_endpoint( # If we've reached this point, the transcription has taken too long return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202) -async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: BackgroundTasks): +async def transcribe_audio(file_path, params: TranscribeParams): L.DEBUG(f"Transcribing audio file from {file_path}...") file_path = await convert_to_wav(file_path) model = params.model if params.model in WHISPER_CPP_MODELS else 'small' @@ -90,6 +89,7 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))]) command.extend(['-np']) # Always enable no-prints + if params.split_on_word: command.append('-sow') if params.temperature > 0: @@ -122,38 +122,56 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr # Create a unique ID for this transcription job job_id = str(uuid.uuid4()) + L.DEBUG(f"Created job ID: {job_id}") # Store the job status transcription_results[job_id] = {"status": "processing", "result": None} - # Run the transcription in a background task - bg_tasks.add_task(process_transcription, command, file_path, job_id) + # Start the transcription process immediately + transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id)) max_wait_time = 300 # 5 minutes poll_interval = 1 # 1 second start_time = asyncio.get_event_loop().time() - while asyncio.get_event_loop().time() - start_time < max_wait_time: - job_status = transcription_results.get(job_id, {}) - if job_status["status"] == "completed": - return job_status["result"] - elif job_status["status"] == "failed": - raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}") - await asyncio.sleep(poll_interval) + L.DEBUG(f"Starting to poll for job {job_id}") + try: + while asyncio.get_event_loop().time() - start_time < max_wait_time: + job_status = transcription_results.get(job_id, {}) + L.DEBUG(f"Current status for job {job_id}: {job_status['status']}") + if job_status["status"] == "completed": + L.INFO(f"Transcription completed for job {job_id}") + return job_status["result"] + elif job_status["status"] == "failed": + L.ERR(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}") + raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}") + await asyncio.sleep(poll_interval) - raise TimeoutError("Transcription timed out") + L.ERR(f"Transcription timed out for job {job_id}") + raise TimeoutError("Transcription timed out") + finally: + # Ensure the task is cancelled if we exit the loop + transcription_task.cancel() + + # This line should never be reached, but just in case: + raise Exception("Unexpected exit from transcription function") async def process_transcription(command, file_path, job_id): try: + L.DEBUG(f"Starting transcription process for job {job_id}") result = await run_transcription(command, file_path) transcription_results[job_id] = {"status": "completed", "result": result} + L.DEBUG(f"Transcription completed for job {job_id}") except Exception as e: + L.ERR(f"Transcription failed for job {job_id}: {str(e)}") transcription_results[job_id] = {"status": "failed", "error": str(e)} finally: # Clean up the temporary file os.remove(file_path) + L.DEBUG(f"Cleaned up temporary file for job {job_id}") async def run_transcription(command, file_path): + L.DEBUG(f"Running transcription command: {' '.join(command)}") proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, @@ -161,7 +179,10 @@ async def run_transcription(command, file_path): ) stdout, stderr = await proc.communicate() if proc.returncode != 0: - raise Exception(f"Error running command: {stderr.decode()}") + error_message = f"Error running command: {stderr.decode()}" + L.ERR(error_message) + raise Exception(error_message) + L.DEBUG("Transcription command completed successfully") return stdout.decode().strip() async def convert_to_wav(file_path: str): diff --git a/sijapi/routers/note.py b/sijapi/routers/note.py index 005ad64..1aa81f0 100644 --- a/sijapi/routers/note.py +++ b/sijapi/routers/note.py @@ -134,7 +134,6 @@ async def build_daily_timeslips(date): @note.post("/clip") async def clip_post( bg_tasks: BackgroundTasks, - file: UploadFile = None, url: Optional[str] = Form(None), source: Optional[str] = Form(None), title: Optional[str] = Form(None), @@ -147,14 +146,12 @@ async def clip_post( @note.post("/archive") async def archive_post( - bg_tasks: BackgroundTasks, - file: UploadFile = None, url: Optional[str] = Form(None), source: Optional[str] = Form(None), title: Optional[str] = Form(None), encoding: str = Form('utf-8') ): - markdown_filename = await process_archive(bg_tasks, url, title, encoding, source) + markdown_filename = await process_archive(url, title, encoding, source) return {"message": "Clip saved successfully", "markdown_filename": markdown_filename} @note.get("/clip") @@ -199,7 +196,7 @@ async def process_for_daily_note(file: Optional[UploadFile] = File(None), text: L.DEBUG(f"Processing {f.name}...") if 'audio' in file_type: - transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6), bg_tasks=bg_tasks) + transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6)) file_entry = f"![[{relative_path}]]" elif 'image' in file_type: @@ -500,7 +497,6 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str] async def process_archive( - bg_tasks: BackgroundTasks, url: str, title: Optional[str] = None, encoding: str = 'utf-8',