mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-18 02:27:10 +00:00
Rename Chat Model Options table to Chat Model as short & readable (#1003)
- Previous was incorrectly plural but was defining only a single model - Rename chat model table field to name - Update documentation - Update references functions and variables to match new name
This commit is contained in:
parent
9be26e1bd2
commit
01bc6d35dc
26 changed files with 369 additions and 308 deletions
|
@ -25,7 +25,7 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
|
|||
- Name: `proxy-name`
|
||||
- Api Key: `any string`
|
||||
- Api Base Url: **URL of your Openai Proxy API**
|
||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
||||
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||
- Name: `llama3.1` (replace with the name of your local model)
|
||||
- Model Type: `Openai`
|
||||
- Openai Config: `<the proxy config you created in step 3>`
|
||||
|
|
|
@ -18,7 +18,7 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
|
|||
- Name: `proxy-name`
|
||||
- Api Key: `any string`
|
||||
- Api Base Url: `http://localhost:1234/v1/` (default for LMStudio)
|
||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
||||
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||
- Name: `llama3.1` (replace with the name of your local model)
|
||||
- Model Type: `Openai`
|
||||
- Openai Config: `<the proxy config you created in step 3>`
|
||||
|
|
|
@ -64,7 +64,7 @@ Restart your Khoj server after first run or update to the settings below to ensu
|
|||
- Name: `ollama`
|
||||
- Api Key: `any string`
|
||||
- Api Base Url: `http://localhost:11434/v1/` (default for Ollama)
|
||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
||||
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||
- Name: `llama3.1` (replace with the name of your local model)
|
||||
- Model Type: `Openai`
|
||||
- Openai Config: `<the ollama config you created in step 3>`
|
||||
|
|
|
@ -25,7 +25,7 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
|
|||
- Name: `any name`
|
||||
- Api Key: `any string`
|
||||
- Api Base Url: **URL of your Openai Proxy API**
|
||||
3. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
||||
3. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||
- Name: `llama3` (replace with the name of your local model)
|
||||
- Model Type: `Openai`
|
||||
- Openai Config: `<the proxy config you created in step 2>`
|
||||
|
|
|
@ -307,7 +307,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
|||
- Give the configuration a friendly name like `OpenAI`
|
||||
- (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).<br />
|
||||
![example configuration for ai model api](/img/example_openai_processor_config.png)
|
||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
||||
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||
- Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`.
|
||||
- Make sure to set the `model-type` field to `OpenAI`.
|
||||
- If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities.
|
||||
|
@ -318,7 +318,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
|||
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
||||
- Add your [Anthropic API key](https://console.anthropic.com/account/keys)
|
||||
- Give the configuration a friendly name like `Anthropic`. Do not configure the API base url.
|
||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
||||
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||
- Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`.
|
||||
- Set the `model-type` field to `Anthropic`.
|
||||
- Set the `ai model api` field to the Anthropic AI Model API you created in step 1.
|
||||
|
@ -327,7 +327,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
|||
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
||||
- Add your [Gemini API key](https://aistudio.google.com/app/apikey)
|
||||
- Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
|
||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
||||
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
|
||||
- Set the `model-type` field to `Gemini`.
|
||||
- Set the `ai model api` field to the Gemini AI Model API you created in step 1.
|
||||
|
@ -343,7 +343,7 @@ Offline chat stays completely private and can work without internet using any op
|
|||
:::
|
||||
|
||||
1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
|
||||
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodeloptions/add/) on the admin panel
|
||||
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
|
||||
3. Set the `chat-model` field to the name of your preferred chat model
|
||||
- Make sure the `model-type` is set to `Offline`
|
||||
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
|
||||
|
|
|
@ -36,7 +36,7 @@ from torch import Tensor
|
|||
from khoj.database.models import (
|
||||
Agent,
|
||||
AiModelApi,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
Entry,
|
||||
|
@ -736,8 +736,8 @@ class AgentAdapters:
|
|||
|
||||
@staticmethod
|
||||
def create_default_agent(user: KhojUser):
|
||||
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
|
||||
if default_conversation_config is None:
|
||||
default_chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||
if default_chat_model is None:
|
||||
logger.info("No default conversation config found, skipping default agent creation")
|
||||
return None
|
||||
default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
|
||||
|
@ -746,7 +746,7 @@ class AgentAdapters:
|
|||
|
||||
if agent:
|
||||
agent.personality = default_personality
|
||||
agent.chat_model = default_conversation_config
|
||||
agent.chat_model = default_chat_model
|
||||
agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
|
||||
agent.name = AgentAdapters.DEFAULT_AGENT_NAME
|
||||
agent.privacy_level = Agent.PrivacyLevel.PUBLIC
|
||||
|
@ -760,7 +760,7 @@ class AgentAdapters:
|
|||
name=AgentAdapters.DEFAULT_AGENT_NAME,
|
||||
privacy_level=Agent.PrivacyLevel.PUBLIC,
|
||||
managed_by_admin=True,
|
||||
chat_model=default_conversation_config,
|
||||
chat_model=default_chat_model,
|
||||
personality=default_personality,
|
||||
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
|
||||
)
|
||||
|
@ -787,7 +787,7 @@ class AgentAdapters:
|
|||
output_modes: List[str],
|
||||
slug: Optional[str] = None,
|
||||
):
|
||||
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
|
||||
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
|
||||
|
||||
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
||||
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
||||
|
@ -972,29 +972,29 @@ class ConversationAdapters:
|
|||
|
||||
@staticmethod
|
||||
@require_valid_user
|
||||
def has_any_conversation_config(user: KhojUser):
|
||||
return ChatModelOptions.objects.filter(user=user).exists()
|
||||
def has_any_chat_model(user: KhojUser):
|
||||
return ChatModel.objects.filter(user=user).exists()
|
||||
|
||||
@staticmethod
|
||||
def get_all_conversation_configs():
|
||||
return ChatModelOptions.objects.all()
|
||||
def get_all_chat_models():
|
||||
return ChatModel.objects.all()
|
||||
|
||||
@staticmethod
|
||||
async def aget_all_conversation_configs():
|
||||
return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all())
|
||||
async def aget_all_chat_models():
|
||||
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
||||
|
||||
@staticmethod
|
||||
def get_vision_enabled_config():
|
||||
conversation_configurations = ConversationAdapters.get_all_conversation_configs()
|
||||
for config in conversation_configurations:
|
||||
chat_models = ConversationAdapters.get_all_chat_models()
|
||||
for config in chat_models:
|
||||
if config.vision_enabled:
|
||||
return config
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def aget_vision_enabled_config():
|
||||
conversation_configurations = await ConversationAdapters.aget_all_conversation_configs()
|
||||
for config in conversation_configurations:
|
||||
chat_models = await ConversationAdapters.aget_all_chat_models()
|
||||
for config in chat_models:
|
||||
if config.vision_enabled:
|
||||
return config
|
||||
return None
|
||||
|
@ -1010,7 +1010,7 @@ class ConversationAdapters:
|
|||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
||||
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
||||
config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst()
|
||||
if not config:
|
||||
return None
|
||||
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||
|
@ -1026,24 +1026,24 @@ class ConversationAdapters:
|
|||
return new_config
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_config(user: KhojUser):
|
||||
def get_chat_model(user: KhojUser):
|
||||
subscribed = is_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return ConversationAdapters.get_default_conversation_config(user)
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
config = UserConversationConfig.objects.filter(user=user).first()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.get_advanced_conversation_config(user)
|
||||
return ConversationAdapters.get_advanced_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_config(user: KhojUser):
|
||||
async def aget_chat_model(user: KhojUser):
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
if not subscribed:
|
||||
return await ConversationAdapters.aget_default_conversation_config(user)
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if config:
|
||||
return config.setting
|
||||
return ConversationAdapters.aget_advanced_conversation_config(user)
|
||||
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||
|
@ -1064,7 +1064,7 @@ class ConversationAdapters:
|
|||
return VoiceModelOption.objects.first()
|
||||
|
||||
@staticmethod
|
||||
def get_default_conversation_config(user: KhojUser = None):
|
||||
def get_default_chat_model(user: KhojUser = None):
|
||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||
# Get the server chat settings
|
||||
server_chat_settings = ServerChatSettings.objects.first()
|
||||
|
@ -1084,10 +1084,10 @@ class ConversationAdapters:
|
|||
return user_chat_settings.setting
|
||||
|
||||
# Get the first chat model if even the user chat settings are not set
|
||||
return ChatModelOptions.objects.filter().first()
|
||||
return ChatModel.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_conversation_config(user: KhojUser = None):
|
||||
async def aget_default_chat_model(user: KhojUser = None):
|
||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||
# Get the server chat settings
|
||||
server_chat_settings: ServerChatSettings = (
|
||||
|
@ -1117,17 +1117,17 @@ class ConversationAdapters:
|
|||
return user_chat_settings.setting
|
||||
|
||||
# Get the first chat model if even the user chat settings are not set
|
||||
return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
|
||||
return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_advanced_conversation_config(user: KhojUser):
|
||||
def get_advanced_chat_model(user: KhojUser):
|
||||
server_chat_settings = ServerChatSettings.objects.first()
|
||||
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||
return server_chat_settings.chat_advanced
|
||||
return ConversationAdapters.get_default_conversation_config(user)
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_advanced_conversation_config(user: KhojUser = None):
|
||||
async def aget_advanced_chat_model(user: KhojUser = None):
|
||||
server_chat_settings: ServerChatSettings = (
|
||||
await ServerChatSettings.objects.filter()
|
||||
.prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
|
||||
|
@ -1135,7 +1135,7 @@ class ConversationAdapters:
|
|||
)
|
||||
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||
return server_chat_settings.chat_advanced
|
||||
return await ConversationAdapters.aget_default_conversation_config(user)
|
||||
return await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_server_webscraper():
|
||||
|
@ -1247,16 +1247,16 @@ class ConversationAdapters:
|
|||
|
||||
@staticmethod
|
||||
def get_conversation_processor_options():
|
||||
return ChatModelOptions.objects.all()
|
||||
return ChatModel.objects.all()
|
||||
|
||||
@staticmethod
|
||||
def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
|
||||
def set_user_chat_model(user: KhojUser, chat_model: ChatModel):
|
||||
user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
|
||||
user_conversation_config.setting = new_config
|
||||
user_conversation_config.setting = chat_model
|
||||
user_conversation_config.save()
|
||||
|
||||
@staticmethod
|
||||
async def aget_user_conversation_config(user: KhojUser):
|
||||
async def aget_user_chat_model(user: KhojUser):
|
||||
config = (
|
||||
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
|
||||
)
|
||||
|
@ -1288,33 +1288,33 @@ class ConversationAdapters:
|
|||
return random.sample(all_questions, max_results)
|
||||
|
||||
@staticmethod
|
||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
||||
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
|
||||
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
||||
if agent and agent.chat_model:
|
||||
conversation_config = conversation.agent.chat_model
|
||||
chat_model = conversation.agent.chat_model
|
||||
else:
|
||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||
chat_model = ConversationAdapters.get_chat_model(user)
|
||||
|
||||
if conversation_config is None:
|
||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||
if chat_model is None:
|
||||
chat_model = ConversationAdapters.get_default_chat_model()
|
||||
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
return conversation_config
|
||||
return chat_model
|
||||
|
||||
if (
|
||||
conversation_config.model_type
|
||||
chat_model.model_type
|
||||
in [
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
ChatModel.ModelType.ANTHROPIC,
|
||||
ChatModel.ModelType.OPENAI,
|
||||
ChatModel.ModelType.GOOGLE,
|
||||
]
|
||||
) and conversation_config.ai_model_api:
|
||||
return conversation_config
|
||||
) and chat_model.ai_model_api:
|
||||
return chat_model
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
||||
|
|
|
@ -16,7 +16,7 @@ from unfold import admin as unfold_admin
|
|||
from khoj.database.models import (
|
||||
Agent,
|
||||
AiModelApi,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
Entry,
|
||||
|
@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
|
|||
list_filter = ("type",)
|
||||
|
||||
|
||||
@admin.register(ChatModelOptions)
|
||||
class ChatModelOptionsAdmin(unfold_admin.ModelAdmin):
|
||||
@admin.register(ChatModel)
|
||||
class ChatModelAdmin(unfold_admin.ModelAdmin):
|
||||
list_display = (
|
||||
"id",
|
||||
"chat_model",
|
||||
"name",
|
||||
"ai_model_api",
|
||||
"max_prompt_size",
|
||||
)
|
||||
search_fields = ("id", "chat_model", "ai_model_api__name")
|
||||
search_fields = ("id", "name", "ai_model_api__name")
|
||||
|
||||
|
||||
@admin.register(TextToImageModelConfig)
|
||||
|
@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
|
|||
"get_chat_model",
|
||||
"get_subscription_type",
|
||||
)
|
||||
search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type")
|
||||
search_fields = ("id", "user__email", "setting__name", "user__subscription__type")
|
||||
ordering = ("-updated_at",)
|
||||
|
||||
def get_user_email(self, obj):
|
||||
|
@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
|
|||
get_user_email.admin_order_field = "user__email" # type: ignore
|
||||
|
||||
def get_chat_model(self, obj):
|
||||
return obj.setting.chat_model if obj.setting else None
|
||||
return obj.setting.name if obj.setting else None
|
||||
|
||||
get_chat_model.short_description = "Chat Model" # type: ignore
|
||||
get_chat_model.admin_order_field = "setting__chat_model" # type: ignore
|
||||
get_chat_model.admin_order_field = "setting__name" # type: ignore
|
||||
|
||||
def get_subscription_type(self, obj):
|
||||
if hasattr(obj.user, "subscription"):
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# Generated by Django 5.0.9 on 2024-12-09 04:21
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RenameModel(
|
||||
old_name="ChatModelOptions",
|
||||
new_name="ChatModel",
|
||||
),
|
||||
migrations.RenameField(
|
||||
model_name="chatmodel",
|
||||
old_name="chat_model",
|
||||
new_name="name",
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="agent",
|
||||
name="chat_model",
|
||||
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="serverchatsettings",
|
||||
name="chat_advanced",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="chat_advanced",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="serverchatsettings",
|
||||
name="chat_default",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="chat_default",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="userconversationconfig",
|
||||
name="setting",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel):
|
|||
return self.name
|
||||
|
||||
|
||||
class ChatModelOptions(DbBaseModel):
|
||||
class ChatModel(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
|
@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel):
|
|||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
vision_enabled = models.BooleanField(default=False)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
def __str__(self):
|
||||
return self.chat_model
|
||||
return self.name
|
||||
|
||||
|
||||
class VoiceModelOption(DbBaseModel):
|
||||
|
@ -297,7 +297,7 @@ class Agent(DbBaseModel):
|
|||
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
|
||||
)
|
||||
managed_by_admin = models.BooleanField(default=False)
|
||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||
chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE)
|
||||
slug = models.CharField(max_length=200, unique=True)
|
||||
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
|
||||
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
|
||||
|
@ -438,10 +438,10 @@ class WebScraper(DbBaseModel):
|
|||
|
||||
class ServerChatSettings(DbBaseModel):
|
||||
chat_default = models.ForeignKey(
|
||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||
)
|
||||
chat_advanced = models.ForeignKey(
|
||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||
)
|
||||
web_scraper = models.ForeignKey(
|
||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||
|
@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel):
|
|||
|
||||
class UserConversationConfig(DbBaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class UserVoiceModelConfig(DbBaseModel):
|
||||
|
|
|
@ -60,7 +60,7 @@ import logging
|
|||
|
||||
from packaging import version
|
||||
|
||||
from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig
|
||||
from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
|
||||
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -98,11 +98,11 @@ def migrate_server_pg(args):
|
|||
|
||||
if "offline-chat" in raw_config["processor"]["conversation"]:
|
||||
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=offline_chat.get("chat-model"),
|
||||
ChatModel.objects.create(
|
||||
name=offline_chat.get("chat-model"),
|
||||
tokenizer=processor_conversation.get("tokenizer"),
|
||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
model_type=ChatModel.ModelType.OFFLINE,
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -119,11 +119,11 @@ def migrate_server_pg(args):
|
|||
|
||||
openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
|
||||
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=openai.get("chat-model"),
|
||||
ChatModel.objects.create(
|
||||
name=openai.get("chat-model"),
|
||||
tokenizer=processor_conversation.get("tokenizer"),
|
||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
ai_model_api=openai_model_api,
|
||||
)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
|||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.utils import (
|
||||
anthropic_chat_completion_with_backoff,
|
||||
|
@ -85,7 +85,7 @@ def extract_questions_anthropic(
|
|||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
@ -218,7 +218,7 @@ def converse_anthropic(
|
|||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||
query_files=query_files,
|
||||
generated_files=generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
|||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
|
@ -86,7 +86,7 @@ def extract_questions_gemini(
|
|||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
model_type=ChatModel.ModelType.GOOGLE,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
@ -229,7 +229,7 @@ def converse_gemini(
|
|||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
model_type=ChatModel.ModelType.GOOGLE,
|
||||
query_files=query_files,
|
||||
generated_files=generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
|
|
|
@ -9,7 +9,7 @@ import pyjson5
|
|||
from langchain.schema import ChatMessage
|
||||
from llama_cpp import Llama
|
||||
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
|
@ -96,7 +96,7 @@ def extract_questions_offline(
|
|||
model_name=model,
|
||||
loaded_model=offline_chat_model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
model_type=ChatModel.ModelType.OFFLINE,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
|
@ -232,7 +232,7 @@ def converse_offline(
|
|||
loaded_model=offline_chat_model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
model_type=ChatModel.ModelType.OFFLINE,
|
||||
query_files=query_files,
|
||||
generated_files=generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
|||
import pyjson5
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
|
@ -83,7 +83,7 @@ def extract_questions(
|
|||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
vision_enabled=vision_enabled,
|
||||
attached_file_context=query_files,
|
||||
)
|
||||
|
@ -220,7 +220,7 @@ def converse_openai(
|
|||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
query_files=query_files,
|
||||
generated_files=generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
|
|
|
@ -24,7 +24,7 @@ from llama_cpp.llama import Llama
|
|||
from transformers import AutoTokenizer
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||
from khoj.database.models import ChatModel, ClientApplication, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
|
@ -330,9 +330,9 @@ def construct_structured_message(
|
|||
Format messages into appropriate multimedia format for supported chat model types
|
||||
"""
|
||||
if model_type in [
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModel.ModelType.OPENAI,
|
||||
ChatModel.ModelType.GOOGLE,
|
||||
ChatModel.ModelType.ANTHROPIC,
|
||||
]:
|
||||
if not attached_file_context and not (vision_enabled and images):
|
||||
return message
|
||||
|
|
|
@ -28,12 +28,7 @@ from khoj.database.adapters import (
|
|||
get_default_search_model,
|
||||
get_user_photo,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
KhojUser,
|
||||
SpeechToTextModelOptions,
|
||||
)
|
||||
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||
extract_questions_anthropic,
|
||||
|
@ -404,15 +399,15 @@ async def extract_references_and_questions(
|
|||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_enabled = conversation_config.vision_enabled
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
vision_enabled = chat_model.vision_enabled
|
||||
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
using_offline_chat = True
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
|
||||
|
@ -424,18 +419,18 @@ async def extract_references_and_questions(
|
|||
should_extract_questions=True,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
personality_context=personality_context,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
base_url = conversation_config.ai_model_api.api_base_url
|
||||
chat_model = conversation_config.chat_model
|
||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
base_url = chat_model.ai_model_api.api_base_url
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions(
|
||||
defiltered_query,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=base_url,
|
||||
conversation_log=meta_log,
|
||||
|
@ -447,13 +442,13 @@ async def extract_references_and_questions(
|
|||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_anthropic(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
|
@ -463,17 +458,17 @@ async def extract_references_and_questions(
|
|||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
chat_model_name = chat_model.name
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
max_tokens=conversation_config.max_prompt_size,
|
||||
max_tokens=chat_model.max_prompt_size,
|
||||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
|
|
|
@ -62,7 +62,7 @@ async def all_agents(
|
|||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.chat_model,
|
||||
"chat_model": agent.chat_model.name,
|
||||
"files": file_names,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
|
@ -150,7 +150,7 @@ async def get_agent(
|
|||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.chat_model,
|
||||
"chat_model": agent.chat_model.name,
|
||||
"files": file_names,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
|
@ -225,7 +225,7 @@ async def create_agent(
|
|||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.chat_model,
|
||||
"chat_model": agent.chat_model.name,
|
||||
"files": body.files,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
|
@ -286,7 +286,7 @@ async def update_agent(
|
|||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.chat_model,
|
||||
"chat_model": agent.chat_model.name,
|
||||
"files": body.files,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
|
|
|
@ -58,7 +58,7 @@ from khoj.routers.helpers import (
|
|||
is_ready_to_chat,
|
||||
read_chat_stream,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
validate_chat_model,
|
||||
)
|
||||
from khoj.routers.research import (
|
||||
InformationCollectionIteration,
|
||||
|
@ -205,7 +205,7 @@ def chat_history(
|
|||
n: Optional[int] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
validate_conversation_config(user)
|
||||
validate_chat_model(user)
|
||||
|
||||
# Load Conversation History
|
||||
conversation = ConversationAdapters.get_conversation_by_user(
|
||||
|
@ -898,10 +898,10 @@ async def chat(
|
|||
custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
if not q:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
model_type = conversation_config.model_type
|
||||
chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
||||
if chat_model == None:
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
model_type = chat_model.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
async for result in send_llm_response(formatted_help, tracer.get("usage")):
|
||||
yield result
|
||||
|
|
|
@ -24,7 +24,7 @@ def get_chat_model_options(
|
|||
|
||||
all_conversation_options = list()
|
||||
for conversation_option in conversation_options:
|
||||
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
||||
all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id})
|
||||
|
||||
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
||||
|
||||
|
@ -37,12 +37,12 @@ def get_user_chat_model(
|
|||
):
|
||||
user = request.user.object
|
||||
|
||||
chat_model = ConversationAdapters.get_conversation_config(user)
|
||||
chat_model = ConversationAdapters.get_chat_model(user)
|
||||
|
||||
if chat_model is None:
|
||||
chat_model = ConversationAdapters.get_default_conversation_config(user)
|
||||
chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name}))
|
||||
|
||||
|
||||
@api_model.post("/chat", status_code=200)
|
||||
|
|
|
@ -56,7 +56,7 @@ from khoj.database.adapters import (
|
|||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
GithubConfig,
|
||||
|
@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool:
|
|||
return is_none_or_empty(query.strip())
|
||||
|
||||
|
||||
def validate_conversation_config(user: KhojUser):
|
||||
default_config = ConversationAdapters.get_default_conversation_config(user)
|
||||
def validate_chat_model(user: KhojUser):
|
||||
default_chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
if default_config is None:
|
||||
if default_chat_model is None:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||
|
||||
if default_config.model_type == "openai" and not default_config.ai_model_api:
|
||||
if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||
|
||||
|
||||
async def is_ready_to_chat(user: KhojUser):
|
||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if user_conversation_config == None:
|
||||
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
||||
if user_chat_model == None:
|
||||
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
chat_model = user_conversation_config.chat_model
|
||||
max_tokens = user_conversation_config.max_prompt_size
|
||||
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
chat_model_name = user_chat_model.name
|
||||
max_tokens = user_chat_model.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
logger.info("Loading Offline Chat Model...")
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
return True
|
||||
|
||||
if (
|
||||
user_conversation_config
|
||||
user_chat_model
|
||||
and (
|
||||
user_conversation_config.model_type
|
||||
user_chat_model.model_type
|
||||
in [
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
ChatModel.ModelType.OPENAI,
|
||||
ChatModel.ModelType.ANTHROPIC,
|
||||
ChatModel.ModelType.GOOGLE,
|
||||
]
|
||||
)
|
||||
and user_conversation_config.ai_model_api
|
||||
and user_chat_model.ai_model_api
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -942,120 +942,124 @@ async def send_message_to_model_wrapper(
|
|||
query_files: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
|
||||
vision_available = chat_model.vision_enabled
|
||||
if not vision_available and query_images:
|
||||
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
|
||||
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
chat_model = vision_enabled_config
|
||||
vision_available = True
|
||||
if vision_available and query_images:
|
||||
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
|
||||
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
|
||||
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
chat_model = conversation_config.chat_model
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = (
|
||||
conversation_config.subscribed_max_prompt_size
|
||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||
else conversation_config.max_prompt_size
|
||||
chat_model.subscribed_max_prompt_size
|
||||
if subscribed and chat_model.subscribed_max_prompt_size
|
||||
else chat_model.max_prompt_size
|
||||
)
|
||||
tokenizer = conversation_config.tokenizer
|
||||
model_type = conversation_config.model_type
|
||||
vision_available = conversation_config.vision_enabled
|
||||
tokenizer = chat_model.tokenizer
|
||||
model_type = chat_model.model_type
|
||||
vision_available = chat_model.vision_enabled
|
||||
|
||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
tokenizer_name=tokenizer,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.ai_model_api
|
||||
elif model_type == ChatModel.ModelType.OPENAI:
|
||||
openai_chat_config = chat_model.ai_model_api
|
||||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync(
|
|||
query_files: str = "",
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
if conversation_config is None:
|
||||
if chat_model is None:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
vision_available = conversation_config.vision_enabled
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
vision_available = chat_model.vision_enabled
|
||||
|
||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
model_type=chat_model.model_type,
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
@ -1229,15 +1233,15 @@ def generate_chat_response(
|
|||
online_results = {}
|
||||
code_results = {}
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
|
||||
vision_available = chat_model.vision_enabled
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
chat_model = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if chat_model.model_type == "offline":
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
user_query=query_to_run,
|
||||
|
@ -1247,9 +1251,9 @@ def generate_chat_response(
|
|||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
model=conversation_config.chat_model,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
model=chat_model.name,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
|
@ -1259,10 +1263,10 @@ def generate_chat_response(
|
|||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.ai_model_api
|
||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
openai_chat_config = chat_model.ai_model_api
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model = conversation_config.chat_model
|
||||
chat_model_name = chat_model.name
|
||||
chat_response = converse_openai(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
|
@ -1270,13 +1274,13 @@ def generate_chat_response(
|
|||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
model=chat_model_name,
|
||||
api_key=api_key,
|
||||
api_base_url=openai_chat_config.api_base_url,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
|
@ -1288,8 +1292,8 @@ def generate_chat_response(
|
|||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
|
@ -1297,12 +1301,12 @@ def generate_chat_response(
|
|||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
conversation_log=meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
|
@ -1313,20 +1317,20 @@ def generate_chat_response(
|
|||
program_execution_context=program_execution_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.ai_model_api.api_key
|
||||
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
query_to_run,
|
||||
online_results,
|
||||
code_results,
|
||||
meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
model=chat_model.name,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
|
@ -1339,7 +1343,7 @@ def generate_chat_response(
|
|||
tracer=tracer,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
metadata.update({"chat_model": chat_model.name})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
|||
current_notion_config = get_user_notion_config(user)
|
||||
notion_token = current_notion_config.token if current_notion_config else ""
|
||||
|
||||
selected_chat_model_config = ConversationAdapters.get_conversation_config(
|
||||
selected_chat_model_config = ConversationAdapters.get_chat_model(
|
||||
user
|
||||
) or ConversationAdapters.get_default_conversation_config(user)
|
||||
) or ConversationAdapters.get_default_chat_model(user)
|
||||
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
||||
chat_model_options = list()
|
||||
for chat_model in chat_models:
|
||||
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
||||
chat_model_options.append({"name": chat_model.name, "id": chat_model.id})
|
||||
|
||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||
|
|
|
@ -7,7 +7,7 @@ import openai
|
|||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import (
|
||||
AiModelApi,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
KhojUser,
|
||||
SpeechToTextModelOptions,
|
||||
TextToImageModelConfig,
|
||||
|
@ -63,7 +63,7 @@ def initialization(interactive: bool = True):
|
|||
|
||||
# Set up OpenAI's online chat models
|
||||
openai_configured, openai_provider = _setup_chat_model_provider(
|
||||
ChatModelOptions.ModelType.OPENAI,
|
||||
ChatModel.ModelType.OPENAI,
|
||||
default_chat_models,
|
||||
default_api_key=openai_api_key,
|
||||
api_base_url=openai_api_base,
|
||||
|
@ -105,7 +105,7 @@ def initialization(interactive: bool = True):
|
|||
|
||||
# Set up Google's Gemini online chat models
|
||||
_setup_chat_model_provider(
|
||||
ChatModelOptions.ModelType.GOOGLE,
|
||||
ChatModel.ModelType.GOOGLE,
|
||||
default_gemini_chat_models,
|
||||
default_api_key=os.getenv("GEMINI_API_KEY"),
|
||||
vision_enabled=True,
|
||||
|
@ -116,7 +116,7 @@ def initialization(interactive: bool = True):
|
|||
|
||||
# Set up Anthropic's online chat models
|
||||
_setup_chat_model_provider(
|
||||
ChatModelOptions.ModelType.ANTHROPIC,
|
||||
ChatModel.ModelType.ANTHROPIC,
|
||||
default_anthropic_chat_models,
|
||||
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
vision_enabled=True,
|
||||
|
@ -126,7 +126,7 @@ def initialization(interactive: bool = True):
|
|||
|
||||
# Set up offline chat models
|
||||
_setup_chat_model_provider(
|
||||
ChatModelOptions.ModelType.OFFLINE,
|
||||
ChatModel.ModelType.OFFLINE,
|
||||
default_offline_chat_models,
|
||||
default_api_key=None,
|
||||
vision_enabled=False,
|
||||
|
@ -135,9 +135,9 @@ def initialization(interactive: bool = True):
|
|||
)
|
||||
|
||||
# Explicitly set default chat model
|
||||
chat_models_configured = ChatModelOptions.objects.count()
|
||||
chat_models_configured = ChatModel.objects.count()
|
||||
if chat_models_configured > 0:
|
||||
default_chat_model_name = ChatModelOptions.objects.first().chat_model
|
||||
default_chat_model_name = ChatModel.objects.first().name
|
||||
# If there are multiple chat models, ask the user to choose the default chat model
|
||||
if chat_models_configured > 1 and interactive:
|
||||
user_chat_model_name = input(
|
||||
|
@ -147,7 +147,7 @@ def initialization(interactive: bool = True):
|
|||
user_chat_model_name = None
|
||||
|
||||
# If the user's choice is valid, set it as the default chat model
|
||||
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
||||
if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists():
|
||||
default_chat_model_name = user_chat_model_name
|
||||
|
||||
logger.info("🗣️ Chat model configuration complete")
|
||||
|
@ -171,7 +171,7 @@ def initialization(interactive: bool = True):
|
|||
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
|
||||
|
||||
def _setup_chat_model_provider(
|
||||
model_type: ChatModelOptions.ModelType,
|
||||
model_type: ChatModel.ModelType,
|
||||
default_chat_models: list,
|
||||
default_api_key: str,
|
||||
interactive: bool,
|
||||
|
@ -226,7 +226,7 @@ def initialization(interactive: bool = True):
|
|||
"ai_model_api": ai_model_api,
|
||||
}
|
||||
|
||||
ChatModelOptions.objects.create(**chat_model_options)
|
||||
ChatModel.objects.create(**chat_model_options)
|
||||
|
||||
logger.info(f"🗣️ {provider_name} chat model configuration complete")
|
||||
return True, ai_model_api
|
||||
|
@ -250,16 +250,16 @@ def initialization(interactive: bool = True):
|
|||
available_models = [model.id for model in openai_client.models.list()]
|
||||
|
||||
# Get existing chat model options for this config
|
||||
existing_models = ChatModelOptions.objects.filter(
|
||||
ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI
|
||||
existing_models = ChatModel.objects.filter(
|
||||
ai_model_api=config, model_type=ChatModel.ModelType.OPENAI
|
||||
)
|
||||
|
||||
# Add new models
|
||||
for model in available_models:
|
||||
if not existing_models.filter(chat_model=model).exists():
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=model,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
if not existing_models.filter(name=model).exists():
|
||||
ChatModel.objects.create(
|
||||
name=model,
|
||||
model_type=ChatModel.ModelType.OPENAI,
|
||||
max_prompt_size=model_to_prompt_size.get(model),
|
||||
vision_enabled=model in default_openai_chat_models,
|
||||
tokenizer=model_to_tokenizer.get(model),
|
||||
|
@ -284,7 +284,7 @@ def initialization(interactive: bool = True):
|
|||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
|
||||
|
||||
chat_config = ConversationAdapters.get_default_conversation_config()
|
||||
chat_config = ConversationAdapters.get_default_chat_model()
|
||||
if admin_user is None and chat_config is None:
|
||||
while True:
|
||||
try:
|
||||
|
|
|
@ -13,7 +13,7 @@ from khoj.configure import (
|
|||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
KhojApiUser,
|
||||
|
@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
|||
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
|
||||
from tests.helpers import (
|
||||
AiModelApiFactory,
|
||||
ChatModelOptionsFactory,
|
||||
ChatModelFactory,
|
||||
ProcessLockFactory,
|
||||
SubscriptionFactory,
|
||||
UserConversationProcessorConfigFactory,
|
||||
|
@ -184,14 +184,14 @@ def api_user4(default_user4):
|
|||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_openai_chat_model_option():
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
||||
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||
return chat_model
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def offline_agent():
|
||||
chat_model = ChatModelOptionsFactory()
|
||||
chat_model = ChatModelFactory()
|
||||
return Agent.objects.create(
|
||||
name="Accountant",
|
||||
chat_model=chat_model,
|
||||
|
@ -202,7 +202,7 @@ def offline_agent():
|
|||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def openai_agent():
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
||||
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||
return Agent.objects.create(
|
||||
name="Accountant",
|
||||
chat_model=chat_model,
|
||||
|
@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
|
|||
|
||||
# Initialize Processor from Config
|
||||
chat_provider = get_chat_provider()
|
||||
online_chat_model: ChatModelOptionsFactory = None
|
||||
if chat_provider == ChatModelOptions.ModelType.OPENAI:
|
||||
online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
||||
elif chat_provider == ChatModelOptions.ModelType.GOOGLE:
|
||||
online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google")
|
||||
elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic")
|
||||
online_chat_model: ChatModelFactory = None
|
||||
if chat_provider == ChatModel.ModelType.OPENAI:
|
||||
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||
elif chat_provider == ChatModel.ModelType.GOOGLE:
|
||||
online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
|
||||
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
|
||||
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
|
||||
if online_chat_model:
|
||||
online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
|
||||
UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
|
||||
|
@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
|||
configure_content(default_user2, all_files)
|
||||
|
||||
# Initialize Processor from Config
|
||||
ChatModelOptionsFactory(
|
||||
chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
|
||||
ChatModelFactory(
|
||||
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
|
||||
tokenizer=None,
|
||||
max_prompt_size=None,
|
||||
model_type="offline",
|
||||
|
|
|
@ -6,7 +6,7 @@ from django.utils.timezone import make_aware
|
|||
|
||||
from khoj.database.models import (
|
||||
AiModelApi,
|
||||
ChatModelOptions,
|
||||
ChatModel,
|
||||
Conversation,
|
||||
KhojApiUser,
|
||||
KhojUser,
|
||||
|
@ -18,27 +18,27 @@ from khoj.database.models import (
|
|||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
|
||||
def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
|
||||
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
|
||||
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
|
||||
if provider and provider in ChatModelOptions.ModelType:
|
||||
return ChatModelOptions.ModelType(provider)
|
||||
if provider and provider in ChatModel.ModelType:
|
||||
return ChatModel.ModelType(provider)
|
||||
elif os.getenv("OPENAI_API_KEY"):
|
||||
return ChatModelOptions.ModelType.OPENAI
|
||||
return ChatModel.ModelType.OPENAI
|
||||
elif os.getenv("GEMINI_API_KEY"):
|
||||
return ChatModelOptions.ModelType.GOOGLE
|
||||
return ChatModel.ModelType.GOOGLE
|
||||
elif os.getenv("ANTHROPIC_API_KEY"):
|
||||
return ChatModelOptions.ModelType.ANTHROPIC
|
||||
return ChatModel.ModelType.ANTHROPIC
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
|
||||
def get_chat_api_key(provider: ChatModel.ModelType = None):
|
||||
provider = provider or get_chat_provider()
|
||||
if provider == ChatModelOptions.ModelType.OPENAI:
|
||||
if provider == ChatModel.ModelType.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif provider == ChatModelOptions.ModelType.GOOGLE:
|
||||
elif provider == ChatModel.ModelType.GOOGLE:
|
||||
return os.getenv("GEMINI_API_KEY")
|
||||
elif provider == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
elif provider == ChatModel.ModelType.ANTHROPIC:
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
else:
|
||||
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
||||
|
@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory):
|
|||
api_key = get_chat_api_key()
|
||||
|
||||
|
||||
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
||||
class ChatModelFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = ChatModelOptions
|
||||
model = ChatModel
|
||||
|
||||
max_prompt_size = 20000
|
||||
tokenizer = None
|
||||
chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
||||
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
||||
model_type = get_chat_provider()
|
||||
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
||||
|
||||
|
@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
|||
model = UserConversationConfig
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
setting = factory.SubFactory(ChatModelOptionsFactory)
|
||||
setting = factory.SubFactory(ChatModelFactory)
|
||||
|
||||
|
||||
class ConversationFactory(factory.django.DjangoModelFactory):
|
||||
|
|
|
@ -5,14 +5,14 @@ import pytest
|
|||
from asgiref.sync import sync_to_async
|
||||
|
||||
from khoj.database.adapters import AgentAdapters
|
||||
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||
from khoj.routers.api import execute_search
|
||||
from khoj.utils.helpers import get_absolute_path
|
||||
from tests.helpers import ChatModelOptionsFactory
|
||||
from tests.helpers import ChatModelFactory
|
||||
|
||||
|
||||
def test_create_default_agent(default_user: KhojUser):
|
||||
ChatModelOptionsFactory()
|
||||
ChatModelFactory()
|
||||
|
||||
agent = AgentAdapters.create_default_agent(default_user)
|
||||
assert agent is not None
|
||||
|
@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser):
|
|||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
|
||||
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel):
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
default_user,
|
||||
"Test Agent",
|
||||
|
@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
|
|||
Agent.PrivacyLevel.PRIVATE,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
|
@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_create_or_update_agent_with_knowledge_base(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base(
|
|||
Agent.PrivacyLevel.PRIVATE,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base(
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_create_or_update_agent_with_knowledge_base_and_search(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
|
|||
Agent.PrivacyLevel.PRIVATE,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_agent_with_knowledge_base_and_search_not_creator(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
|
|||
Agent.PrivacyLevel.PUBLIC,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
|||
Agent.PrivacyLevel.PRIVATE,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
|
|||
Agent.PrivacyLevel.PRIVATE,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_multiple_agents_with_knowledge_base_and_users(
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||
):
|
||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||
new_agent = await AgentAdapters.aupdate_agent(
|
||||
|
@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
|
|||
Agent.PrivacyLevel.PUBLIC,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename],
|
||||
[],
|
||||
[],
|
||||
|
@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
|
|||
Agent.PrivacyLevel.PUBLIC,
|
||||
"icon",
|
||||
"color",
|
||||
default_openai_chat_model_option.chat_model,
|
||||
default_openai_chat_model_option.name,
|
||||
[full_filename2],
|
||||
[],
|
||||
[],
|
||||
|
|
|
@ -2,12 +2,12 @@ from datetime import datetime
|
|||
|
||||
import pytest
|
||||
|
||||
from khoj.database.models import ChatModelOptions
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
||||
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="Disable in CI to avoid long test runs.",
|
||||
|
|
|
@ -4,12 +4,12 @@ import pytest
|
|||
from faker import Faker
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
|
||||
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from tests.helpers import ConversationFactory, get_chat_provider
|
||||
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="Disable in CI to avoid long test runs.",
|
||||
|
|
Loading…
Reference in a new issue