Fixed asr and finally made it fully async

This commit is contained in:
sanj 2024-06-27 19:43:27 -07:00
parent 6acb9e4d8e
commit 7290bf4295
2 changed files with 37 additions and 20 deletions

View file

@ -41,7 +41,6 @@ transcription_results = {}
@asr.post("/transcribe")
@asr.post("/v1/audio/transcription")
async def transcribe_endpoint(
bg_tasks: BackgroundTasks,
file: UploadFile = File(...),
params: str = Form(...)
):
@ -58,7 +57,7 @@ async def transcribe_endpoint(
temp_file.write(await file.read())
temp_file_path = temp_file.name
transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters, bg_tasks=bg_tasks)
transcription_job = await transcribe_audio(file_path=temp_file_path, params=parameters)
job_id = transcription_job["job_id"]
# Poll for completion
@ -80,7 +79,7 @@ async def transcribe_endpoint(
# If we've reached this point, the transcription has taken too long
return JSONResponse(content={"status": "timeout", "message": "Transcription is taking longer than expected. Please check back later."}, status_code=202)
async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: BackgroundTasks):
async def transcribe_audio(file_path, params: TranscribeParams):
L.DEBUG(f"Transcribing audio file from {file_path}...")
file_path = await convert_to_wav(file_path)
model = params.model if params.model in WHISPER_CPP_MODELS else 'small'
@ -90,6 +89,7 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr
command.extend(['-t', str(max(1, min(params.threads or MAX_CPU_CORES, MAX_CPU_CORES)))])
command.extend(['-np']) # Always enable no-prints
if params.split_on_word:
command.append('-sow')
if params.temperature > 0:
@ -122,38 +122,56 @@ async def transcribe_audio(file_path, params: TranscribeParams, bg_tasks: Backgr
# Create a unique ID for this transcription job
job_id = str(uuid.uuid4())
L.DEBUG(f"Created job ID: {job_id}")
# Store the job status
transcription_results[job_id] = {"status": "processing", "result": None}
# Run the transcription in a background task
bg_tasks.add_task(process_transcription, command, file_path, job_id)
# Start the transcription process immediately
transcription_task = asyncio.create_task(process_transcription(command, file_path, job_id))
max_wait_time = 300 # 5 minutes
poll_interval = 1 # 1 second
start_time = asyncio.get_event_loop().time()
while asyncio.get_event_loop().time() - start_time < max_wait_time:
job_status = transcription_results.get(job_id, {})
if job_status["status"] == "completed":
return job_status["result"]
elif job_status["status"] == "failed":
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
await asyncio.sleep(poll_interval)
L.DEBUG(f"Starting to poll for job {job_id}")
try:
while asyncio.get_event_loop().time() - start_time < max_wait_time:
job_status = transcription_results.get(job_id, {})
L.DEBUG(f"Current status for job {job_id}: {job_status['status']}")
if job_status["status"] == "completed":
L.INFO(f"Transcription completed for job {job_id}")
return job_status["result"]
elif job_status["status"] == "failed":
L.ERR(f"Transcription failed for job {job_id}: {job_status.get('error', 'Unknown error')}")
raise Exception(f"Transcription failed: {job_status.get('error', 'Unknown error')}")
await asyncio.sleep(poll_interval)
raise TimeoutError("Transcription timed out")
L.ERR(f"Transcription timed out for job {job_id}")
raise TimeoutError("Transcription timed out")
finally:
# Ensure the task is cancelled if we exit the loop
transcription_task.cancel()
# This line should never be reached, but just in case:
raise Exception("Unexpected exit from transcription function")
async def process_transcription(command, file_path, job_id):
    """Run one transcription job and record its outcome in ``transcription_results``.

    Args:
        command: Command passed through unchanged to ``run_transcription``.
        file_path: Path passed to ``run_transcription``; the file is removed
            once the job finishes, whether it succeeded or failed.
        job_id: Key under which the job's status entry is stored.

    Errors derived from ``Exception`` raised by ``run_transcription`` are not
    propagated; they are stored as ``{"status": "failed", "error": <message>}``
    so that code polling ``transcription_results`` can observe the failure.
    """
    try:
        L.DEBUG(f"Starting transcription process for job {job_id}")
        result = await run_transcription(command, file_path)
        transcription_results[job_id] = {"status": "completed", "result": result}
        L.DEBUG(f"Transcription completed for job {job_id}")
    except Exception as e:
        L.ERR(f"Transcription failed for job {job_id}: {str(e)}")
        transcription_results[job_id] = {"status": "failed", "error": str(e)}
    finally:
        # Clean up the temporary file. The job status has already been
        # recorded at this point, so a cleanup failure (file already gone,
        # permission problem) must not escape from this coroutine: it would
        # surface only as an unretrieved task exception and hide the real
        # outcome of the job.
        try:
            os.remove(file_path)
        except OSError as e:
            L.ERR(f"Could not remove temporary file for job {job_id}: {e}")
        else:
            L.DEBUG(f"Cleaned up temporary file for job {job_id}")
async def run_transcription(command, file_path):
L.DEBUG(f"Running transcription command: {' '.join(command)}")
proc = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
@ -161,7 +179,10 @@ async def run_transcription(command, file_path):
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise Exception(f"Error running command: {stderr.decode()}")
error_message = f"Error running command: {stderr.decode()}"
L.ERR(error_message)
raise Exception(error_message)
L.DEBUG("Transcription command completed successfully")
return stdout.decode().strip()
async def convert_to_wav(file_path: str):

View file

@ -134,7 +134,6 @@ async def build_daily_timeslips(date):
@note.post("/clip")
async def clip_post(
bg_tasks: BackgroundTasks,
file: UploadFile = None,
url: Optional[str] = Form(None),
source: Optional[str] = Form(None),
title: Optional[str] = Form(None),
@ -147,14 +146,12 @@ async def clip_post(
# POST /archive handler: passes the submitted form fields to process_archive
# and returns the resulting markdown filename in a dict.
# NOTE(review): this span comes from a rendered commit diff, so it appears to
# show both the pre-commit lines (the `bg_tasks` parameter and the call that
# passes it) and the post-commit call without it — confirm against the file.
@note.post("/archive")
async def archive_post(
bg_tasks: BackgroundTasks,
# `file` is accepted but not referenced anywhere in the visible body.
file: UploadFile = None,
url: Optional[str] = Form(None),
source: Optional[str] = Form(None),
title: Optional[str] = Form(None),
encoding: str = Form('utf-8')
):
markdown_filename = await process_archive(bg_tasks, url, title, encoding, source)
markdown_filename = await process_archive(url, title, encoding, source)
# NOTE(review): the message says "Clip" although this is the /archive
# endpoint — confirm the wording is intended.
return {"message": "Clip saved successfully", "markdown_filename": markdown_filename}
@note.get("/clip")
@ -199,7 +196,7 @@ async def process_for_daily_note(file: Optional[UploadFile] = File(None), text:
L.DEBUG(f"Processing {f.name}...")
if 'audio' in file_type:
transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6), bg_tasks=bg_tasks)
transcription = await asr.transcribe_audio(file_path=absolute_path, params=asr.TranscribeParams(model="small-en", language="en", threads=6))
file_entry = f"![[{relative_path}]]"
elif 'image' in file_type:
@ -500,7 +497,6 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str]
async def process_archive(
bg_tasks: BackgroundTasks,
url: str,
title: Optional[str] = None,
encoding: str = 'utf-8',