Mirror of https://github.com/khoj-ai/khoj.git
Modularize chat model initialization with a reusable function
The chat model initialization interaction flow is fairly similar across chat model providers. Factoring it into a reusable function should simplify adding new chat model providers and reduce the chance of bugs in the interactive chat model initialization flow.
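For orientation, the shared helper introduced by this patch has the following contract (signature copied from the diff below; the comment summarizing its behavior is editorial). A usage sketch for wiring up an additional provider follows the diff.

    def _setup_chat_model_provider(
        model_type: ChatModelOptions.ModelType,
        default_chat_models: list,
        default_api_key: str,
        interactive: bool,
        vision_enabled: bool = False,
        is_offline: bool = False,
        provider_name: str = None,
    ) -> Tuple[bool, OpenAIProcessorConversationConfig]:
        # Asks whether to add the provider (or decides from default_api_key /
        # is_offline when non-interactive), collects an API key for online
        # providers, asks which chat models to enable, and persists one
        # ChatModelOptions row per model. Returns whether the provider was
        # configured and the created provider config (None for offline models).
        ...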
commit 2033f5168e
parent 26c39576df
1 changed file with 100 additions and 126 deletions
@@ -1,5 +1,6 @@
 import logging
 import os
+from typing import Tuple

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
@@ -41,41 +42,18 @@ def initialization(interactive: bool = True):
             "🗣️ Configure chat models available to your server. You can always update these at /server/admin using your admin account"
         )

-        # Set up OpenAI's online models
-        default_openai_api_key = os.getenv("OPENAI_API_KEY")
-        default_use_openai_model = {True: "y", False: "n"}[default_openai_api_key != None]
-        use_model_provider = default_use_openai_model if not interactive else input("Add OpenAI models? (y/n): ")
-        if use_model_provider == "y":
-            logger.info("️💬 Setting up your OpenAI configuration")
-            if interactive:
-                user_api_key = input(f"Enter your OpenAI API key (default: {default_openai_api_key}): ")
-                api_key = user_api_key if user_api_key != "" else default_openai_api_key
-            else:
-                api_key = default_openai_api_key
-            chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="OpenAI")
+        # Set up OpenAI's online chat models
+        openai_configured, openai_provider = _setup_chat_model_provider(
+            ChatModelOptions.ModelType.OPENAI,
+            default_openai_chat_models,
+            default_api_key=os.getenv("OPENAI_API_KEY"),
+            vision_enabled=True,
+            is_offline=False,
+            interactive=interactive,
+        )

-            if interactive:
-                chat_model_names = input(
-                    f"Enter the OpenAI chat models you want to use (default: {','.join(default_openai_chat_models)}): "
-                )
-                chat_models = chat_model_names.split(",") if chat_model_names != "" else default_openai_chat_models
-                chat_models = [model.strip() for model in chat_models]
-            else:
-                chat_models = default_openai_chat_models
-
-            # Add OpenAI chat models
-            for chat_model in chat_models:
-                vision_enabled = chat_model in ["gpt-4o-mini", "gpt-4o"]
-                default_max_tokens = model_to_prompt_size.get(chat_model)
-                ChatModelOptions.objects.create(
-                    chat_model=chat_model,
-                    model_type=ChatModelOptions.ModelType.OPENAI,
-                    max_prompt_size=default_max_tokens,
-                    openai_config=chat_model_provider,
-                    vision_enabled=vision_enabled,
-                )
-
-            # Add OpenAI speech to text model
+        # Setup OpenAI speech to text model
+        if openai_configured:
             default_speech2text_model = "whisper-1"
             if interactive:
                 openai_speech2text_model = input(
@@ -88,7 +66,8 @@ def initialization(interactive: bool = True):
                 model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
             )

-            # Add OpenAI text to image model
+        # Setup OpenAI text to image model
+        if openai_configured:
             default_text_to_image_model = "dall-e-3"
             if interactive:
                 openai_text_to_image_model = input(
@@ -98,107 +77,44 @@
             else:
                 openai_text_to_image_model = default_text_to_image_model
             TextToImageModelConfig.objects.create(
-                model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
+                model_name=openai_text_to_image_model,
+                model_type=TextToImageModelConfig.ModelType.OPENAI,
+                openai_config=openai_provider,
             )

         # Set up Google's Gemini online chat models
-        default_gemini_api_key = os.getenv("GEMINI_API_KEY")
-        default_use_gemini_model = {True: "y", False: "n"}[default_gemini_api_key != None]
-        use_model_provider = default_use_gemini_model if not interactive else input("Add Google's chat models? (y/n): ")
-        if use_model_provider == "y":
-            logger.info("️💬 Setting up your Google Gemini configuration")
-            if interactive:
-                user_api_key = input(f"Enter your Gemini API key (default: {default_gemini_api_key}): ")
-                api_key = user_api_key if user_api_key != "" else default_gemini_api_key
-            else:
-                api_key = default_gemini_api_key
-            chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Gemini")
-
-            if interactive:
-                chat_model_names = input(
-                    f"Enter the Gemini chat models you want to use (default: {','.join(default_gemini_chat_models)}): "
-                )
-                chat_models = chat_model_names.split(",") if chat_model_names != "" else default_gemini_chat_models
-                chat_models = [model.strip() for model in chat_models]
-            else:
-                chat_models = default_gemini_chat_models
-
-            # Add Gemini chat models
-            for chat_model in chat_models:
-                default_max_tokens = model_to_prompt_size.get(chat_model)
-                vision_enabled = False
-                ChatModelOptions.objects.create(
-                    chat_model=chat_model,
-                    model_type=ChatModelOptions.ModelType.GOOGLE,
-                    max_prompt_size=default_max_tokens,
-                    openai_config=chat_model_provider,
-                    vision_enabled=False,
-                )
+        _setup_chat_model_provider(
+            ChatModelOptions.ModelType.GOOGLE,
+            default_gemini_chat_models,
+            default_api_key=os.getenv("GEMINI_API_KEY"),
+            vision_enabled=False,
+            is_offline=False,
+            interactive=interactive,
+            provider_name="Google Gemini",
+        )

         # Set up Anthropic's online chat models
-        default_anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
-        default_use_anthropic_model = {True: "y", False: "n"}[default_anthropic_api_key != None]
-        use_model_provider = (
-            default_use_anthropic_model if not interactive else input("Add Anthropic's chat models? (y/n): ")
+        _setup_chat_model_provider(
+            ChatModelOptions.ModelType.ANTHROPIC,
+            default_anthropic_chat_models,
+            default_api_key=os.getenv("ANTHROPIC_API_KEY"),
+            vision_enabled=False,
+            is_offline=False,
+            interactive=interactive,
         )
-        if use_model_provider == "y":
-            logger.info("️💬 Setting up your Anthropic configuration")
-            if interactive:
-                user_api_key = input(f"Enter your Anthropic API key (default: {default_anthropic_api_key}): ")
-                api_key = user_api_key if user_api_key != "" else default_anthropic_api_key
-            else:
-                api_key = default_anthropic_api_key
-            chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name="Anthropic")
-
-            if interactive:
-                chat_model_names = input(
-                    f"Enter the Anthropic chat models you want to use (default: {','.join(default_anthropic_chat_models)}): "
-                )
-                chat_models = chat_model_names.split(",") if chat_model_names != "" else default_anthropic_chat_models
-                chat_models = [model.strip() for model in chat_models]
-            else:
-                chat_models = default_anthropic_chat_models
-
-            # Add Anthropic chat models
-            for chat_model in chat_models:
-                vision_enabled = False
-                default_max_tokens = model_to_prompt_size.get(chat_model)
-                ChatModelOptions.objects.create(
-                    chat_model=chat_model,
-                    model_type=ChatModelOptions.ModelType.ANTHROPIC,
-                    max_prompt_size=default_max_tokens,
-                    openai_config=chat_model_provider,
-                    vision_enabled=False,
-                )

         # Set up offline chat models
-        use_model_provider = "y" if not interactive else input("Add Offline chat models? (y/n): ")
-        if use_model_provider == "y":
-            logger.info("️💬 Setting up Offline chat models")
-
-            if interactive:
-                chat_model_names = input(
-                    f"Enter the offline chat models you want to use. See HuggingFace for available GGUF models (default: {','.join(default_offline_chat_models)}): "
-                )
-                chat_models = chat_model_names.split(",") if chat_model_names != "" else default_offline_chat_models
-                chat_models = [model.strip() for model in chat_models]
-            else:
-                chat_models = default_offline_chat_models
-
-            # Add chat models
-            for chat_model in chat_models:
-                default_max_tokens = model_to_prompt_size.get(chat_model)
-                default_tokenizer = model_to_tokenizer.get(chat_model)
-                ChatModelOptions.objects.create(
-                    chat_model=chat_model,
-                    model_type=ChatModelOptions.ModelType.OFFLINE,
-                    max_prompt_size=default_max_tokens,
-                    tokenizer=default_tokenizer,
-                )
-
-        chat_models_configured = ChatModelOptions.objects.count()
+        _setup_chat_model_provider(
+            ChatModelOptions.ModelType.OFFLINE,
+            default_offline_chat_models,
+            default_api_key=None,
+            vision_enabled=False,
+            is_offline=True,
+            interactive=interactive,
+        )

+        # Explicitly set default chat model
+        chat_models_configured = ChatModelOptions.objects.count()
         if chat_models_configured > 0:
             default_chat_model_name = ChatModelOptions.objects.first().chat_model
             # If there are multiple chat models, ask the user to choose the default chat model
@@ -236,6 +152,64 @@ def initialization(interactive: bool = True):

             logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")

+    def _setup_chat_model_provider(
+        model_type: ChatModelOptions.ModelType,
+        default_chat_models: list,
+        default_api_key: str,
+        interactive: bool,
+        vision_enabled: bool = False,
+        is_offline: bool = False,
+        provider_name: str = None,
+    ) -> Tuple[bool, OpenAIProcessorConversationConfig]:
+        supported_vision_models = ["gpt-4o-mini", "gpt-4o"]
+        provider_name = provider_name or model_type.name.capitalize()
+        default_use_model = {True: "y", False: "n"}[default_api_key is not None or is_offline]
+        use_model_provider = (
+            default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ")
+        )
+
+        if use_model_provider != "y":
+            return False, None
+
+        logger.info(f"️💬 Setting up your {provider_name} chat configuration")
+
+        chat_model_provider = None
+        if not is_offline:
+            if interactive:
+                user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
+                api_key = user_api_key if user_api_key != "" else default_api_key
+            else:
+                api_key = default_api_key
+            chat_model_provider = OpenAIProcessorConversationConfig.objects.create(api_key=api_key, name=provider_name)
+
+        if interactive:
+            chat_model_names = input(
+                f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
+            )
+            chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
+            chat_models = [model.strip() for model in chat_models]
+        else:
+            chat_models = default_chat_models
+
+        for chat_model in chat_models:
+            default_max_tokens = model_to_prompt_size.get(chat_model)
+            default_tokenizer = model_to_tokenizer.get(chat_model)
+            vision_enabled = vision_enabled and chat_model in supported_vision_models
+
+            chat_model_options = {
+                "chat_model": chat_model,
+                "model_type": model_type,
+                "max_prompt_size": default_max_tokens,
+                "vision_enabled": vision_enabled,
+                "tokenizer": default_tokenizer,
+                "openai_config": chat_model_provider,
+            }
+
+            ChatModelOptions.objects.create(**chat_model_options)
+
+        logger.info(f"🗣️ {provider_name} chat model configuration complete")
+        return True, chat_model_provider
+
     admin_user = KhojUser.objects.filter(is_staff=True).first()
     if admin_user is None:
         while True:
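As an illustration of the commit message's claim that the helper simplifies adding new chat model providers, a further provider now reduces to a single call. A minimal sketch, assuming a hypothetical ChatModelOptions.ModelType.MISTRAL enum member, a default_mistral_chat_models list, and a MISTRAL_API_KEY environment variable, none of which exist in this commit:

    # Hypothetical: register Mistral chat models through the shared helper.
    mistral_configured, mistral_provider = _setup_chat_model_provider(
        ChatModelOptions.ModelType.MISTRAL,  # hypothetical enum member
        default_mistral_chat_models,  # hypothetical defaults list
        default_api_key=os.getenv("MISTRAL_API_KEY"),
        vision_enabled=False,
        is_offline=False,
        interactive=interactive,
        provider_name="Mistral",
    )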