mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Create speech to text API endpoint. Use OpenAI whisper for ASR
- Wrap audio transcription in try/catch and delete audio file after processing - Use configured speech to text model, else handle error
This commit is contained in:
parent
1ca99b6eb0
commit
cc77bc4076
1 changed files with 52 additions and 1 deletions
|
@ -1,13 +1,16 @@
|
|||
# Standard Packages
|
||||
import concurrent.futures
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
from typing import Annotated, List, Optional, Union, Any
|
||||
import uuid
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
|
||||
import openai
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
@ -553,6 +556,54 @@ async def chat_options(
|
|||
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api.post("/speak")
|
||||
@requires(["authenticated"])
|
||||
async def transcribe_audio(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
|
||||
user: KhojUser = request.user.object
|
||||
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
||||
user_message: str = None
|
||||
|
||||
# Transcribe the audio from the request
|
||||
try:
|
||||
# Store the audio from the request in a temporary file
|
||||
audio_data = await file.read()
|
||||
with open(audio_filename, "wb") as audio_file_writer:
|
||||
audio_file_writer.write(audio_data)
|
||||
audio_file = open(audio_filename, "rb")
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not openai_chat_config or not speech_to_text_config:
|
||||
# If the user has not configured a speech to text model, return an unprocessable entity error
|
||||
status_code = 422
|
||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = openai_chat_config.api_key
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
response = await sync_to_async(openai.Audio.translate)(
|
||||
model=speech2text_model, file=audio_file, api_key=api_key
|
||||
)
|
||||
user_message = response["text"]
|
||||
finally:
|
||||
# Close and Delete the temporary audio file
|
||||
audio_file.close()
|
||||
os.remove(audio_filename)
|
||||
|
||||
if user_message is None:
|
||||
return Response(status_code=status_code or 500)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="speech_to_text",
|
||||
**common.__dict__,
|
||||
)
|
||||
|
||||
# Return the spoken text
|
||||
content = json.dumps({"text": user_message})
|
||||
return Response(content=content, media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
@api.get("/chat", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
|
Loading…
Add table
Reference in a new issue