Fixed asr and finally made it fully async

This commit is contained in:
sanj 2024-06-27 19:43:27 -07:00
parent 6acb9e4d8e
commit 7290bf4295
2 changed files with 37 additions and 20 deletions

View file

@ -41,7 +41,6 @@ transcription_results = {}
@asr.post("/transcribe") @asr.post("/transcribe")
@asr.post("/v1/audio/transcription") @asr.post("/v1/audio/transcription")
async def transcribe_endpoint( async def transcribe_endpoint(
bg_tasks: BackgroundTasks,
file: UploadFile = File(...), file: UploadFile = File(...),
params: str = Form(...) params: str = Form(...)
): ):
@ -58,7 +57,7 @@ async def transcribe_endpoint(
temp_file.write(await file.read()) temp_file.write(await file.read())
temp_file_path = temp_file.name temp_file_path = temp_file.name
transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters, bg_tasks=bg_tasks) transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters)
job_id = transcription_job["job_id"] job_id = transcription_job["job_id"]
# Poll for completion # Poll for completion
@ -80,7 +79,7 @@ async def transcribe_endpoint(
# If we've reached this point, the transcription has taken too long # If we've reached this point, the transcription has taken too long
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202) return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: BackgroundTasks): async def transcribe_audio(file_path, params: TranscribeParams):
L.DEBUG(f"Transcribing audio file from {file_path}...") L.DEBUG(f"Transcribing audio file from {file_path}...")
file_path = await convert_to_wav(file_path) file_path = await convert_to_wav(file_path)
model = params.model if params.model in WHISPER_CPP_MODELS else 'small' model = params.model if params.model in WHISPER_CPP_MODELS else 'small'
@ -90,6 +89,7 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr
command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))]) command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))])
command.extend(['-np']) # Always enable no-prints command.extend(['-np']) # Always enable no-prints
if params.split_on_word: if params.split_on_word:
command.append('-sow') command.append('-sow')
if params.temperature > 0: if params.temperature > 0:
@ -122,38 +122,56 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr
# Create a unique ID for this transcription job # Create a unique ID for this transcription job
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
L.DEBUG(f"Created job ID: {job_id}")
# Store the job status # Store the job status
transcription_results[job_id] = {"status": "processing", "result": None} transcription_results[job_id] = {"status": "processing", "result": None}
# Run the transcription in a background task # Start the transcription process immediately
bg_tasks.add_task(process_transcription, command, file_path, job_id) transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
max_wait_time = 300 # 5 minutes max_wait_time = 300 # 5 minutes
poll_interval = 1 # 1 second poll_interval = 1 # 1 second
start_time = asyncio.get_event_loop().time() start_time = asyncio.get_event_loop().time()
while asyncio.get_event_loop().time() - start_time < max_wait_time: L.DEBUG(f"Starting to poll for job {job_id}")
job_status = transcription_results.get(job_id, {}) try:
if job_status["status"] == "completed": while asyncio.get_event_loop().time() - start_time < max_wait_time:
return job_status["result"] job_status = transcription_results.get(job_id, {})
elif job_status["status"] == "failed": L.DEBUG(f"Current status for job {job_id}: {job_status['status']}")
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}") if job_status["status"] == "completed":
await asyncio.sleep(poll_interval) L.INFO(f"Transcription completed for job {job_id}")
return job_status["result"]
elif job_status["status"] == "failed":
L.ERR(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
await asyncio.sleep(poll_interval)
raise TimeoutError("Transcription timed out") L.ERR(f"Transcription timed out for job {job_id}")
raise TimeoutError("Transcription timed out")
finally:
# Ensure the task is cancelled if we exit the loop
transcription_task.cancel()
# This line should never be reached, but just in case:
raise Exception("Unexpected exit from transcription function")
async def process_transcription(command, file_path, job_id): async def process_transcription(command, file_path, job_id):
try: try:
L.DEBUG(f"Starting transcription process for job {job_id}")
result = await run_transcription(command, file_path) result = await run_transcription(command, file_path)
transcription_results[job_id] = {"status": "completed", "result": result} transcription_results[job_id] = {"status": "completed", "result": result}
L.DEBUG(f"Transcription completed for job {job_id}")
except Exception as e: except Exception as e:
L.ERR(f"Transcription failed for job {job_id}: {str(e)}")
transcription_results[job_id] = {"status": "failed", "error": str(e)} transcription_results[job_id] = {"status": "failed", "error": str(e)}
finally: finally:
# Clean up the temporary file # Clean up the temporary file
os.remove(file_path) os.remove(file_path)
L.DEBUG(f"Cleaned up temporary file for job {job_id}")
async def run_transcription(command, file_path): async def run_transcription(command, file_path):
L.DEBUG(f"Running transcription command: {' '.join(command)}")
proc = await asyncio.create_subprocess_exec( proc = await asyncio.create_subprocess_exec(
*command, *command,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
@ -161,7 +179,10 @@ async def run_transcription(command, file_path):
) )
stdout, stderr = await proc.communicate() stdout, stderr = await proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise Exception(f"Error running command: {stderr.decode()}") error_message = f"Error running command: {stderr.decode()}"
L.ERR(error_message)
raise Exception(error_message)
L.DEBUG("Transcription command completed successfully")
return stdout.decode().strip() return stdout.decode().strip()
async def convert_to_wav(file_path: str): async def convert_to_wav(file_path: str):

View file

@ -134,7 +134,6 @@ async def build_daily_timeslips(date):
@note.post("/clip") @note.post("/clip")
async def clip_post( async def clip_post(
bg_tasks: BackgroundTasks, bg_tasks: BackgroundTasks,
file: UploadFile = None,
url: Optional[str] = Form(None), url: Optional[str] = Form(None),
source: Optional[str] = Form(None), source: Optional[str] = Form(None),
title: Optional[str] = Form(None), title: Optional[str] = Form(None),
@ -147,14 +146,12 @@ async def clip_post(
@note.post("/archive") @note.post("/archive")
async def archive_post( async def archive_post(
bg_tasks: BackgroundTasks,
file: UploadFile = None,
url: Optional[str] = Form(None), url: Optional[str] = Form(None),
source: Optional[str] = Form(None), source: Optional[str] = Form(None),
title: Optional[str] = Form(None), title: Optional[str] = Form(None),
encoding: str = Form('utf-8') encoding: str = Form('utf-8')
): ):
markdown_filename = await process_archive(bg_tasks, url, title, encoding, source) markdown_filename = await process_archive(url, title, encoding, source)
return {"message": "Clip saved successfully", "markdown_filename": markdown_filename} return {"message": "Clip saved successfully", "markdown_filename": markdown_filename}
@note.get("/clip") @note.get("/clip")
@ -199,7 +196,7 @@ async def process_for_daily_note(file: Optional[UploadFile] = File(None), text:
L.DEBUG(f"Processing {f.name}...") L.DEBUG(f"Processing {f.name}...")
if 'audio' in file_type: if 'audio' in file_type:
transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6), bg_tasks=bg_tasks) transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6))
file_entry = f"![[{relative_path}]]" file_entry = f"![[{relative_path}]]"
elif 'image' in file_type: elif 'image' in file_type:
@ -500,7 +497,6 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str]
async def process_archive( async def process_archive(
bg_tasks: BackgroundTasks,
url: str, url: str,
title: Optional[str] = None, title: Optional[str] = None,
encoding: str = 'utf-8', encoding: str = 'utf-8',