Mirror of https://github.com/khoj-ai/khoj.git (synced 2025-02-17 08:04:21 +00:00)
Transcribe speech to text offline with Whisper

- Allow server admin to configure offline speech to text model during initialization
- Use offline speech to text model to transcribe audio from clients
- Set offline Whisper as the default speech to text model, since it requires no setup or API key
parent a0a7ab7ec8
commit 4636390f7f
7 changed files with 52 additions and 12 deletions
pyproject.toml

@@ -75,6 +75,7 @@ dependencies = [
     "tzdata == 2023.3",
     "rapidocr-onnxruntime == 1.3.8",
     "stripe == 7.3.0",
+    "openai-whisper >= 20231117",
 ]
 dynamic = ["version"]
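The new openai-whisper dependency is the open-source Whisper speech recognition package, which runs entirely on the local machine. A minimal sketch of the API it exposes (the audio path "sample.mp3" is a hypothetical placeholder):

# Minimal openai-whisper usage sketch; "sample.mp3" is a hypothetical file.
import whisper

model = whisper.load_model("base")       # downloads model weights on first use
result = model.transcribe("sample.mp3")  # returns a dict including the transcript
print(result["text"])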
Django migration for SpeechToTextModelOptions

@@ -1,4 +1,4 @@
-# Generated by Django 4.2.7 on 2023-11-26 09:37
+# Generated by Django 4.2.7 on 2023-11-26 13:54
 
 from django.db import migrations, models
@@ -15,11 +15,11 @@ class Migration(migrations.Migration):
                 ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                 ("created_at", models.DateTimeField(auto_now_add=True)),
                 ("updated_at", models.DateTimeField(auto_now=True)),
-                ("model_name", models.CharField(default="whisper-1", max_length=200)),
+                ("model_name", models.CharField(default="base", max_length=200)),
                 (
                     "model_type",
                     models.CharField(
-                        choices=[("openai", "Openai"), ("offline", "Offline")], default="openai", max_length=200
+                        choices=[("openai", "Openai"), ("offline", "Offline")], default="offline", max_length=200
                     ),
                 ),
             ],
khoj/database/models.py

@@ -125,8 +125,8 @@ class SpeechToTextModelOptions(BaseModel):
         OPENAI = "openai"
         OFFLINE = "offline"
 
-    model_name = models.CharField(max_length=200, default="whisper-1")
-    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+    model_name = models.CharField(max_length=200, default="base")
+    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
 
 
 class ChatModelOptions(BaseModel):
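With these defaults, a fresh install stores an offline "base" Whisper configuration and needs no API key. A hedged sketch of reading that row directly through the Django ORM (the commit itself goes through ConversationAdapters.get_speech_to_text_config(); the direct query below is for illustration only):

# Illustrative ORM query; the app uses ConversationAdapters instead.
from khoj.database.models import SpeechToTextModelOptions

config = SpeechToTextModelOptions.objects.first()
if config and config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
    print(f"Offline Whisper model: {config.model_name}")  # e.g. "base"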
src/khoj/processor/conversation/offline/whisper.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# External Packages
+from asgiref.sync import sync_to_async
+import whisper
+
+# Internal Packages
+from khoj.utils import state
+
+
+async def transcribe_audio_offline(audio_filename: str, model: str) -> str | None:
+    """
+    Transcribe audio file offline using Whisper
+    """
+    # Send the audio data to the Whisper API
+    if not state.whisper_model:
+        state.whisper_model = whisper.load_model(model)
+    response = await sync_to_async(state.whisper_model.transcribe)(audio_filename)
+    return response["text"]
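The helper lazily loads the Whisper model into the shared state module on first use (see the khoj/utils/state.py change below), so later calls reuse the cached model. A usage sketch, assuming a hypothetical local file "audio.wav" and a standalone asyncio driver; in the app this coroutine is awaited from the transcribe endpoint rather than run standalone:

# Illustrative driver for transcribe_audio_offline; "audio.wav" is hypothetical.
import asyncio

from khoj.processor.conversation.offline.whisper import transcribe_audio_offline

async def main():
    # First call loads the Whisper model into state.whisper_model;
    # later calls reuse the cached model.
    text = await transcribe_audio_offline("audio.wav", model="base")
    print(text)

asyncio.run(main())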
API routes (transcribe endpoint)

@@ -31,6 +31,7 @@ from khoj.database.models import (
     NotionConfig,
 )
 from khoj.processor.conversation.offline.chat_model import extract_questions_offline
+from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
 from khoj.processor.conversation.openai.gpt import extract_questions
 from khoj.processor.conversation.openai.whisper import transcribe_audio
 from khoj.processor.conversation.prompts import help_message, no_entries_found
@@ -605,13 +606,16 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
         # Send the audio data to the Whisper API
         speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
         openai_chat_config = await ConversationAdapters.get_openai_chat_config()
-        if not openai_chat_config or not speech_to_text_config:
+        if not speech_to_text_config:
             # If the user has not configured a speech to text model, return an unprocessable entity error
             status_code = 422
-        elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
+        elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
             api_key = openai_chat_config.api_key
             speech2text_model = speech_to_text_config.model_name
-            user_message = await transcribe_audio(model=speech2text_model, audio_file=audio_file, api_key=api_key)
+            user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
+        elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+            speech2text_model = speech_to_text_config.model_name
+            user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
     finally:
         # Close and Delete the temporary audio file
         audio_file.close()
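Design note: the endpoint now requires only a speech-to-text config and dispatches on model_type; the OpenAI chat config (and its API key) is needed only on the OpenAI branch. Because Whisper's transcribe call is synchronous and CPU-heavy, the offline helper wraps it in asgiref's sync_to_async so it runs in a worker thread instead of blocking the event loop. A minimal sketch of that pattern (blocking_work is a hypothetical stand-in for transcription):

# Sketch of the sync_to_async pattern; blocking_work is hypothetical.
import asyncio
import time

from asgiref.sync import sync_to_async

def blocking_work(seconds: float) -> str:
    time.sleep(seconds)  # stands in for the heavy transcription call
    return "done"

async def handler() -> str:
    # Executes blocking_work in a thread pool, keeping the event loop free.
    return await sync_to_async(blocking_work)(0.1)

print(asyncio.run(handler()))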
Server initialization flow

@@ -74,10 +74,9 @@ def initialization():
         except ModuleNotFoundError as e:
             logger.warning("Offline models are not supported on this device.")
 
-        use_openai_model = input("Use OpenAI chat model? (y/n): ")
-
+        use_openai_model = input("Use OpenAI models? (y/n): ")
         if use_openai_model == "y":
-            logger.info("🗣️ Setting up OpenAI chat model")
+            logger.info("🗣️ Setting up your OpenAI configuration")
             api_key = input("Enter your OpenAI API key: ")
             OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
@@ -104,7 +103,25 @@ def initialization():
                 model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
             )
 
-        logger.info("🗣️ Chat model configuration complete")
+        if use_offline_model == "y" or use_openai_model == "y":
+            logger.info("🗣️ Chat model configuration complete")
+
+        use_offline_speech2text_model = input("Use offline speech to text model? (y/n): ")
+        if use_offline_speech2text_model == "y":
+            logger.info("🗣️ Setting up offline speech to text model")
+            # Delete any existing speech to text model options. There can only be one.
+            SpeechToTextModelOptions.objects.all().delete()
+
+            default_offline_speech2text_model = "base"
+            offline_speech2text_model = input(
+                f"Enter the Whisper model to use Offline (default: {default_offline_speech2text_model}): "
+            )
+            offline_speech2text_model = offline_speech2text_model or default_offline_speech2text_model
+            SpeechToTextModelOptions.objects.create(
+                model_name=offline_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OFFLINE
+            )
+
+            logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
 
         admin_user = KhojUser.objects.filter(is_staff=True).first()
         if admin_user is None:
khoj/utils/state.py

@@ -21,6 +21,7 @@ embeddings_model: EmbeddingsModel = None
 cross_encoder_model: CrossEncoderModel = None
 content_index = ContentIndex()
 gpt4all_processor_config: GPT4AllProcessorModel = None
+whisper_model = None
 config_file: Path = None
 verbose: int = 0
 host: str = None
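Keeping whisper_model as a module-level global in khoj.utils.state makes the loaded model a process-wide singleton: the first transcription request pays the model-load cost, and every later request reuses it.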