mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-01-07 03:58:08 +00:00
Rename Chat Model Options table to Chat Model as short & readable (#1003)
- Previous was incorrectly plural but was defining only a single model - Rename chat model table field to name - Update documentation - Update references functions and variables to match new name
This commit is contained in:
parent
9be26e1bd2
commit
01bc6d35dc
26 changed files with 369 additions and 308 deletions
|
@ -25,7 +25,7 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
|
||||||
- Name: `proxy-name`
|
- Name: `proxy-name`
|
||||||
- Api Key: `any string`
|
- Api Key: `any string`
|
||||||
- Api Base Url: **URL of your Openai Proxy API**
|
- Api Base Url: **URL of your Openai Proxy API**
|
||||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||||
- Name: `llama3.1` (replace with the name of your local model)
|
- Name: `llama3.1` (replace with the name of your local model)
|
||||||
- Model Type: `Openai`
|
- Model Type: `Openai`
|
||||||
- Openai Config: `<the proxy config you created in step 3>`
|
- Openai Config: `<the proxy config you created in step 3>`
|
||||||
|
|
|
@ -18,7 +18,7 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
|
||||||
- Name: `proxy-name`
|
- Name: `proxy-name`
|
||||||
- Api Key: `any string`
|
- Api Key: `any string`
|
||||||
- Api Base Url: `http://localhost:1234/v1/` (default for LMStudio)
|
- Api Base Url: `http://localhost:1234/v1/` (default for LMStudio)
|
||||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||||
- Name: `llama3.1` (replace with the name of your local model)
|
- Name: `llama3.1` (replace with the name of your local model)
|
||||||
- Model Type: `Openai`
|
- Model Type: `Openai`
|
||||||
- Openai Config: `<the proxy config you created in step 3>`
|
- Openai Config: `<the proxy config you created in step 3>`
|
||||||
|
|
|
@ -64,7 +64,7 @@ Restart your Khoj server after first run or update to the settings below to ensu
|
||||||
- Name: `ollama`
|
- Name: `ollama`
|
||||||
- Api Key: `any string`
|
- Api Key: `any string`
|
||||||
- Api Base Url: `http://localhost:11434/v1/` (default for Ollama)
|
- Api Base Url: `http://localhost:11434/v1/` (default for Ollama)
|
||||||
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||||
- Name: `llama3.1` (replace with the name of your local model)
|
- Name: `llama3.1` (replace with the name of your local model)
|
||||||
- Model Type: `Openai`
|
- Model Type: `Openai`
|
||||||
- Openai Config: `<the ollama config you created in step 3>`
|
- Openai Config: `<the ollama config you created in step 3>`
|
||||||
|
|
|
@ -25,7 +25,7 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
|
||||||
- Name: `any name`
|
- Name: `any name`
|
||||||
- Api Key: `any string`
|
- Api Key: `any string`
|
||||||
- Api Base Url: **URL of your Openai Proxy API**
|
- Api Base Url: **URL of your Openai Proxy API**
|
||||||
3. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
|
3. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
|
||||||
- Name: `llama3` (replace with the name of your local model)
|
- Name: `llama3` (replace with the name of your local model)
|
||||||
- Model Type: `Openai`
|
- Model Type: `Openai`
|
||||||
- Openai Config: `<the proxy config you created in step 2>`
|
- Openai Config: `<the proxy config you created in step 2>`
|
||||||
|
|
|
@ -307,7 +307,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
||||||
- Give the configuration a friendly name like `OpenAI`
|
- Give the configuration a friendly name like `OpenAI`
|
||||||
- (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).<br />
|
- (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).<br />
|
||||||
![example configuration for ai model api](/img/example_openai_processor_config.png)
|
![example configuration for ai model api](/img/example_openai_processor_config.png)
|
||||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||||
- Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`.
|
- Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`.
|
||||||
- Make sure to set the `model-type` field to `OpenAI`.
|
- Make sure to set the `model-type` field to `OpenAI`.
|
||||||
- If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities.
|
- If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities.
|
||||||
|
@ -318,7 +318,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
||||||
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
||||||
- Add your [Anthropic API key](https://console.anthropic.com/account/keys)
|
- Add your [Anthropic API key](https://console.anthropic.com/account/keys)
|
||||||
- Give the configuration a friendly name like `Anthropic`. Do not configure the API base url.
|
- Give the configuration a friendly name like `Anthropic`. Do not configure the API base url.
|
||||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||||
- Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`.
|
- Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`.
|
||||||
- Set the `model-type` field to `Anthropic`.
|
- Set the `model-type` field to `Anthropic`.
|
||||||
- Set the `ai model api` field to the Anthropic AI Model API you created in step 1.
|
- Set the `ai model api` field to the Anthropic AI Model API you created in step 1.
|
||||||
|
@ -327,7 +327,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
|
||||||
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
|
||||||
- Add your [Gemini API key](https://aistudio.google.com/app/apikey)
|
- Add your [Gemini API key](https://aistudio.google.com/app/apikey)
|
||||||
- Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
|
- Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
|
||||||
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
|
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
|
||||||
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
|
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
|
||||||
- Set the `model-type` field to `Gemini`.
|
- Set the `model-type` field to `Gemini`.
|
||||||
- Set the `ai model api` field to the Gemini AI Model API you created in step 1.
|
- Set the `ai model api` field to the Gemini AI Model API you created in step 1.
|
||||||
|
@ -343,7 +343,7 @@ Offline chat stays completely private and can work without internet using any op
|
||||||
:::
|
:::
|
||||||
|
|
||||||
1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
|
1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
|
||||||
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodeloptions/add/) on the admin panel
|
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
|
||||||
3. Set the `chat-model` field to the name of your preferred chat model
|
3. Set the `chat-model` field to the name of your preferred chat model
|
||||||
- Make sure the `model-type` is set to `Offline`
|
- Make sure the `model-type` is set to `Offline`
|
||||||
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
|
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
|
||||||
|
|
|
@ -36,7 +36,7 @@ from torch import Tensor
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
AiModelApi,
|
AiModelApi,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
Conversation,
|
Conversation,
|
||||||
Entry,
|
Entry,
|
||||||
|
@ -736,8 +736,8 @@ class AgentAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_default_agent(user: KhojUser):
|
def create_default_agent(user: KhojUser):
|
||||||
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
|
default_chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||||
if default_conversation_config is None:
|
if default_chat_model is None:
|
||||||
logger.info("No default conversation config found, skipping default agent creation")
|
logger.info("No default conversation config found, skipping default agent creation")
|
||||||
return None
|
return None
|
||||||
default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
|
default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
|
||||||
|
@ -746,7 +746,7 @@ class AgentAdapters:
|
||||||
|
|
||||||
if agent:
|
if agent:
|
||||||
agent.personality = default_personality
|
agent.personality = default_personality
|
||||||
agent.chat_model = default_conversation_config
|
agent.chat_model = default_chat_model
|
||||||
agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
|
agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
|
||||||
agent.name = AgentAdapters.DEFAULT_AGENT_NAME
|
agent.name = AgentAdapters.DEFAULT_AGENT_NAME
|
||||||
agent.privacy_level = Agent.PrivacyLevel.PUBLIC
|
agent.privacy_level = Agent.PrivacyLevel.PUBLIC
|
||||||
|
@ -760,7 +760,7 @@ class AgentAdapters:
|
||||||
name=AgentAdapters.DEFAULT_AGENT_NAME,
|
name=AgentAdapters.DEFAULT_AGENT_NAME,
|
||||||
privacy_level=Agent.PrivacyLevel.PUBLIC,
|
privacy_level=Agent.PrivacyLevel.PUBLIC,
|
||||||
managed_by_admin=True,
|
managed_by_admin=True,
|
||||||
chat_model=default_conversation_config,
|
chat_model=default_chat_model,
|
||||||
personality=default_personality,
|
personality=default_personality,
|
||||||
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
|
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
|
||||||
)
|
)
|
||||||
|
@ -787,7 +787,7 @@ class AgentAdapters:
|
||||||
output_modes: List[str],
|
output_modes: List[str],
|
||||||
slug: Optional[str] = None,
|
slug: Optional[str] = None,
|
||||||
):
|
):
|
||||||
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
|
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
|
||||||
|
|
||||||
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
||||||
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
||||||
|
@ -972,29 +972,29 @@ class ConversationAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@require_valid_user
|
@require_valid_user
|
||||||
def has_any_conversation_config(user: KhojUser):
|
def has_any_chat_model(user: KhojUser):
|
||||||
return ChatModelOptions.objects.filter(user=user).exists()
|
return ChatModel.objects.filter(user=user).exists()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_conversation_configs():
|
def get_all_chat_models():
|
||||||
return ChatModelOptions.objects.all()
|
return ChatModel.objects.all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_all_conversation_configs():
|
async def aget_all_chat_models():
|
||||||
return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all())
|
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_vision_enabled_config():
|
def get_vision_enabled_config():
|
||||||
conversation_configurations = ConversationAdapters.get_all_conversation_configs()
|
chat_models = ConversationAdapters.get_all_chat_models()
|
||||||
for config in conversation_configurations:
|
for config in chat_models:
|
||||||
if config.vision_enabled:
|
if config.vision_enabled:
|
||||||
return config
|
return config
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_vision_enabled_config():
|
async def aget_vision_enabled_config():
|
||||||
conversation_configurations = await ConversationAdapters.aget_all_conversation_configs()
|
chat_models = await ConversationAdapters.aget_all_chat_models()
|
||||||
for config in conversation_configurations:
|
for config in chat_models:
|
||||||
if config.vision_enabled:
|
if config.vision_enabled:
|
||||||
return config
|
return config
|
||||||
return None
|
return None
|
||||||
|
@ -1010,7 +1010,7 @@ class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@arequire_valid_user
|
@arequire_valid_user
|
||||||
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
|
||||||
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
|
config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst()
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||||
|
@ -1026,24 +1026,24 @@ class ConversationAdapters:
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_config(user: KhojUser):
|
def get_chat_model(user: KhojUser):
|
||||||
subscribed = is_user_subscribed(user)
|
subscribed = is_user_subscribed(user)
|
||||||
if not subscribed:
|
if not subscribed:
|
||||||
return ConversationAdapters.get_default_conversation_config(user)
|
return ConversationAdapters.get_default_chat_model(user)
|
||||||
config = UserConversationConfig.objects.filter(user=user).first()
|
config = UserConversationConfig.objects.filter(user=user).first()
|
||||||
if config:
|
if config:
|
||||||
return config.setting
|
return config.setting
|
||||||
return ConversationAdapters.get_advanced_conversation_config(user)
|
return ConversationAdapters.get_advanced_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_conversation_config(user: KhojUser):
|
async def aget_chat_model(user: KhojUser):
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
if not subscribed:
|
if not subscribed:
|
||||||
return await ConversationAdapters.aget_default_conversation_config(user)
|
return await ConversationAdapters.aget_default_chat_model(user)
|
||||||
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||||
if config:
|
if config:
|
||||||
return config.setting
|
return config.setting
|
||||||
return ConversationAdapters.aget_advanced_conversation_config(user)
|
return ConversationAdapters.aget_advanced_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
|
||||||
|
@ -1064,7 +1064,7 @@ class ConversationAdapters:
|
||||||
return VoiceModelOption.objects.first()
|
return VoiceModelOption.objects.first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_conversation_config(user: KhojUser = None):
|
def get_default_chat_model(user: KhojUser = None):
|
||||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
# Get the server chat settings
|
# Get the server chat settings
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
|
@ -1084,10 +1084,10 @@ class ConversationAdapters:
|
||||||
return user_chat_settings.setting
|
return user_chat_settings.setting
|
||||||
|
|
||||||
# Get the first chat model if even the user chat settings are not set
|
# Get the first chat model if even the user chat settings are not set
|
||||||
return ChatModelOptions.objects.filter().first()
|
return ChatModel.objects.filter().first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_conversation_config(user: KhojUser = None):
|
async def aget_default_chat_model(user: KhojUser = None):
|
||||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
# Get the server chat settings
|
# Get the server chat settings
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
|
@ -1117,17 +1117,17 @@ class ConversationAdapters:
|
||||||
return user_chat_settings.setting
|
return user_chat_settings.setting
|
||||||
|
|
||||||
# Get the first chat model if even the user chat settings are not set
|
# Get the first chat model if even the user chat settings are not set
|
||||||
return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
|
return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_advanced_conversation_config(user: KhojUser):
|
def get_advanced_chat_model(user: KhojUser):
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
return ConversationAdapters.get_default_conversation_config(user)
|
return ConversationAdapters.get_default_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_advanced_conversation_config(user: KhojUser = None):
|
async def aget_advanced_chat_model(user: KhojUser = None):
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
await ServerChatSettings.objects.filter()
|
await ServerChatSettings.objects.filter()
|
||||||
.prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
|
.prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
|
||||||
|
@ -1135,7 +1135,7 @@ class ConversationAdapters:
|
||||||
)
|
)
|
||||||
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
return await ConversationAdapters.aget_default_conversation_config(user)
|
return await ConversationAdapters.aget_default_chat_model(user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_server_webscraper():
|
async def aget_server_webscraper():
|
||||||
|
@ -1247,16 +1247,16 @@ class ConversationAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_processor_options():
|
def get_conversation_processor_options():
|
||||||
return ChatModelOptions.objects.all()
|
return ChatModel.objects.all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
|
def set_user_chat_model(user: KhojUser, chat_model: ChatModel):
|
||||||
user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
|
user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
|
||||||
user_conversation_config.setting = new_config
|
user_conversation_config.setting = chat_model
|
||||||
user_conversation_config.save()
|
user_conversation_config.save()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_user_conversation_config(user: KhojUser):
|
async def aget_user_chat_model(user: KhojUser):
|
||||||
config = (
|
config = (
|
||||||
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
|
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
|
||||||
)
|
)
|
||||||
|
@ -1288,33 +1288,33 @@ class ConversationAdapters:
|
||||||
return random.sample(all_questions, max_results)
|
return random.sample(all_questions, max_results)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
|
||||||
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
|
||||||
if agent and agent.chat_model:
|
if agent and agent.chat_model:
|
||||||
conversation_config = conversation.agent.chat_model
|
chat_model = conversation.agent.chat_model
|
||||||
else:
|
else:
|
||||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
chat_model = ConversationAdapters.get_chat_model(user)
|
||||||
|
|
||||||
if conversation_config is None:
|
if chat_model is None:
|
||||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
chat_model = ConversationAdapters.get_default_chat_model()
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = chat_model.max_prompt_size
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
|
|
||||||
return conversation_config
|
return chat_model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
conversation_config.model_type
|
chat_model.model_type
|
||||||
in [
|
in [
|
||||||
ChatModelOptions.ModelType.ANTHROPIC,
|
ChatModel.ModelType.ANTHROPIC,
|
||||||
ChatModelOptions.ModelType.OPENAI,
|
ChatModel.ModelType.OPENAI,
|
||||||
ChatModelOptions.ModelType.GOOGLE,
|
ChatModel.ModelType.GOOGLE,
|
||||||
]
|
]
|
||||||
) and conversation_config.ai_model_api:
|
) and chat_model.ai_model_api:
|
||||||
return conversation_config
|
return chat_model
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
||||||
|
|
|
@ -16,7 +16,7 @@ from unfold import admin as unfold_admin
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
AiModelApi,
|
AiModelApi,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
Conversation,
|
Conversation,
|
||||||
Entry,
|
Entry,
|
||||||
|
@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
|
||||||
list_filter = ("type",)
|
list_filter = ("type",)
|
||||||
|
|
||||||
|
|
||||||
@admin.register(ChatModelOptions)
|
@admin.register(ChatModel)
|
||||||
class ChatModelOptionsAdmin(unfold_admin.ModelAdmin):
|
class ChatModelAdmin(unfold_admin.ModelAdmin):
|
||||||
list_display = (
|
list_display = (
|
||||||
"id",
|
"id",
|
||||||
"chat_model",
|
"name",
|
||||||
"ai_model_api",
|
"ai_model_api",
|
||||||
"max_prompt_size",
|
"max_prompt_size",
|
||||||
)
|
)
|
||||||
search_fields = ("id", "chat_model", "ai_model_api__name")
|
search_fields = ("id", "name", "ai_model_api__name")
|
||||||
|
|
||||||
|
|
||||||
@admin.register(TextToImageModelConfig)
|
@admin.register(TextToImageModelConfig)
|
||||||
|
@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
|
||||||
"get_chat_model",
|
"get_chat_model",
|
||||||
"get_subscription_type",
|
"get_subscription_type",
|
||||||
)
|
)
|
||||||
search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type")
|
search_fields = ("id", "user__email", "setting__name", "user__subscription__type")
|
||||||
ordering = ("-updated_at",)
|
ordering = ("-updated_at",)
|
||||||
|
|
||||||
def get_user_email(self, obj):
|
def get_user_email(self, obj):
|
||||||
|
@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
|
||||||
get_user_email.admin_order_field = "user__email" # type: ignore
|
get_user_email.admin_order_field = "user__email" # type: ignore
|
||||||
|
|
||||||
def get_chat_model(self, obj):
|
def get_chat_model(self, obj):
|
||||||
return obj.setting.chat_model if obj.setting else None
|
return obj.setting.name if obj.setting else None
|
||||||
|
|
||||||
get_chat_model.short_description = "Chat Model" # type: ignore
|
get_chat_model.short_description = "Chat Model" # type: ignore
|
||||||
get_chat_model.admin_order_field = "setting__chat_model" # type: ignore
|
get_chat_model.admin_order_field = "setting__name" # type: ignore
|
||||||
|
|
||||||
def get_subscription_type(self, obj):
|
def get_subscription_type(self, obj):
|
||||||
if hasattr(obj.user, "subscription"):
|
if hasattr(obj.user, "subscription"):
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Generated by Django 5.0.9 on 2024-12-09 04:21
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RenameModel(
|
||||||
|
old_name="ChatModelOptions",
|
||||||
|
new_name="ChatModel",
|
||||||
|
),
|
||||||
|
migrations.RenameField(
|
||||||
|
model_name="chatmodel",
|
||||||
|
old_name="chat_model",
|
||||||
|
new_name="name",
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="agent",
|
||||||
|
name="chat_model",
|
||||||
|
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="chat_advanced",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="chat_advanced",
|
||||||
|
to="database.chatmodel",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="chat_default",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="chat_default",
|
||||||
|
to="database.chatmodel",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="userconversationconfig",
|
||||||
|
name="setting",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
to="database.chatmodel",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class ChatModelOptions(DbBaseModel):
|
class ChatModel(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
|
@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel):
|
||||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
vision_enabled = models.BooleanField(default=False)
|
vision_enabled = models.BooleanField(default=False)
|
||||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.chat_model
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
class VoiceModelOption(DbBaseModel):
|
class VoiceModelOption(DbBaseModel):
|
||||||
|
@ -297,7 +297,7 @@ class Agent(DbBaseModel):
|
||||||
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
|
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
|
||||||
)
|
)
|
||||||
managed_by_admin = models.BooleanField(default=False)
|
managed_by_admin = models.BooleanField(default=False)
|
||||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE)
|
||||||
slug = models.CharField(max_length=200, unique=True)
|
slug = models.CharField(max_length=200, unique=True)
|
||||||
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
|
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
|
||||||
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
|
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
|
||||||
|
@ -438,10 +438,10 @@ class WebScraper(DbBaseModel):
|
||||||
|
|
||||||
class ServerChatSettings(DbBaseModel):
|
class ServerChatSettings(DbBaseModel):
|
||||||
chat_default = models.ForeignKey(
|
chat_default = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||||
)
|
)
|
||||||
chat_advanced = models.ForeignKey(
|
chat_advanced = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||||
)
|
)
|
||||||
web_scraper = models.ForeignKey(
|
web_scraper = models.ForeignKey(
|
||||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||||
|
@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel):
|
||||||
|
|
||||||
class UserConversationConfig(DbBaseModel):
|
class UserConversationConfig(DbBaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class UserVoiceModelConfig(DbBaseModel):
|
class UserVoiceModelConfig(DbBaseModel):
|
||||||
|
|
|
@ -60,7 +60,7 @@ import logging
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig
|
from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
|
||||||
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -98,11 +98,11 @@ def migrate_server_pg(args):
|
||||||
|
|
||||||
if "offline-chat" in raw_config["processor"]["conversation"]:
|
if "offline-chat" in raw_config["processor"]["conversation"]:
|
||||||
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
|
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
|
||||||
ChatModelOptions.objects.create(
|
ChatModel.objects.create(
|
||||||
chat_model=offline_chat.get("chat-model"),
|
name=offline_chat.get("chat-model"),
|
||||||
tokenizer=processor_conversation.get("tokenizer"),
|
tokenizer=processor_conversation.get("tokenizer"),
|
||||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
model_type=ChatModel.ModelType.OFFLINE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -119,11 +119,11 @@ def migrate_server_pg(args):
|
||||||
|
|
||||||
openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
|
openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
|
||||||
|
|
||||||
ChatModelOptions.objects.create(
|
ChatModel.objects.create(
|
||||||
chat_model=openai.get("chat-model"),
|
name=openai.get("chat-model"),
|
||||||
tokenizer=processor_conversation.get("tokenizer"),
|
tokenizer=processor_conversation.get("tokenizer"),
|
||||||
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
max_prompt_size=processor_conversation.get("max-prompt-size"),
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModel.ModelType.OPENAI,
|
||||||
ai_model_api=openai_model_api,
|
ai_model_api=openai_model_api,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.anthropic.utils import (
|
from khoj.processor.conversation.anthropic.utils import (
|
||||||
anthropic_chat_completion_with_backoff,
|
anthropic_chat_completion_with_backoff,
|
||||||
|
@ -85,7 +85,7 @@ def extract_questions_anthropic(
|
||||||
prompt = construct_structured_message(
|
prompt = construct_structured_message(
|
||||||
message=prompt,
|
message=prompt,
|
||||||
images=query_images,
|
images=query_images,
|
||||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
attached_file_context=query_files,
|
attached_file_context=query_files,
|
||||||
)
|
)
|
||||||
|
@ -218,7 +218,7 @@ def converse_anthropic(
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
model_type=ChatModel.ModelType.ANTHROPIC,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_asset_results=generated_asset_results,
|
generated_asset_results=generated_asset_results,
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.google.utils import (
|
from khoj.processor.conversation.google.utils import (
|
||||||
format_messages_for_gemini,
|
format_messages_for_gemini,
|
||||||
|
@ -86,7 +86,7 @@ def extract_questions_gemini(
|
||||||
prompt = construct_structured_message(
|
prompt = construct_structured_message(
|
||||||
message=prompt,
|
message=prompt,
|
||||||
images=query_images,
|
images=query_images,
|
||||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
model_type=ChatModel.ModelType.GOOGLE,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
attached_file_context=query_files,
|
attached_file_context=query_files,
|
||||||
)
|
)
|
||||||
|
@ -229,7 +229,7 @@ def converse_gemini(
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
model_type=ChatModel.ModelType.GOOGLE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_asset_results=generated_asset_results,
|
generated_asset_results=generated_asset_results,
|
||||||
|
|
|
@ -9,7 +9,7 @@ import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from llama_cpp import Llama
|
from llama_cpp import Llama
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
|
@ -96,7 +96,7 @@ def extract_questions_offline(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
model_type=ChatModel.ModelType.OFFLINE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ def converse_offline(
|
||||||
loaded_model=offline_chat_model,
|
loaded_model=offline_chat_model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
model_type=ChatModel.ModelType.OFFLINE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_asset_results=generated_asset_results,
|
generated_asset_results=generated_asset_results,
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Optional
|
||||||
import pyjson5
|
import pyjson5
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
from khoj.database.models import Agent, ChatModel, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.openai.utils import (
|
from khoj.processor.conversation.openai.utils import (
|
||||||
chat_completion_with_backoff,
|
chat_completion_with_backoff,
|
||||||
|
@ -83,7 +83,7 @@ def extract_questions(
|
||||||
prompt = construct_structured_message(
|
prompt = construct_structured_message(
|
||||||
message=prompt,
|
message=prompt,
|
||||||
images=query_images,
|
images=query_images,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModel.ModelType.OPENAI,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
attached_file_context=query_files,
|
attached_file_context=query_files,
|
||||||
)
|
)
|
||||||
|
@ -220,7 +220,7 @@ def converse_openai(
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModel.ModelType.OPENAI,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
generated_files=generated_files,
|
generated_files=generated_files,
|
||||||
generated_asset_results=generated_asset_results,
|
generated_asset_results=generated_asset_results,
|
||||||
|
|
|
@ -24,7 +24,7 @@ from llama_cpp.llama import Llama
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
from khoj.database.models import ChatModel, ClientApplication, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
|
@ -330,9 +330,9 @@ def construct_structured_message(
|
||||||
Format messages into appropriate multimedia format for supported chat model types
|
Format messages into appropriate multimedia format for supported chat model types
|
||||||
"""
|
"""
|
||||||
if model_type in [
|
if model_type in [
|
||||||
ChatModelOptions.ModelType.OPENAI,
|
ChatModel.ModelType.OPENAI,
|
||||||
ChatModelOptions.ModelType.GOOGLE,
|
ChatModel.ModelType.GOOGLE,
|
||||||
ChatModelOptions.ModelType.ANTHROPIC,
|
ChatModel.ModelType.ANTHROPIC,
|
||||||
]:
|
]:
|
||||||
if not attached_file_context and not (vision_enabled and images):
|
if not attached_file_context and not (vision_enabled and images):
|
||||||
return message
|
return message
|
||||||
|
|
|
@ -28,12 +28,7 @@ from khoj.database.adapters import (
|
||||||
get_default_search_model,
|
get_default_search_model,
|
||||||
get_user_photo,
|
get_user_photo,
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
|
||||||
Agent,
|
|
||||||
ChatModelOptions,
|
|
||||||
KhojUser,
|
|
||||||
SpeechToTextModelOptions,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
extract_questions_anthropic,
|
extract_questions_anthropic,
|
||||||
|
@ -404,15 +399,15 @@ async def extract_references_and_questions(
|
||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||||
vision_enabled = conversation_config.vision_enabled
|
vision_enabled = chat_model.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = chat_model.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
|
||||||
|
@ -424,18 +419,18 @@ async def extract_references_and_questions(
|
||||||
should_extract_questions=True,
|
should_extract_questions=True,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
base_url = conversation_config.ai_model_api.api_base_url
|
base_url = chat_model.ai_model_api.api_base_url
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
inferred_queries = extract_questions(
|
inferred_queries = extract_questions(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=base_url,
|
api_base_url=base_url,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
|
@ -447,13 +442,13 @@ async def extract_references_and_questions(
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
inferred_queries = extract_questions_anthropic(
|
inferred_queries = extract_questions_anthropic(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
|
@ -463,17 +458,17 @@ async def extract_references_and_questions(
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
inferred_queries = extract_questions_gemini(
|
inferred_queries = extract_questions_gemini(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
max_tokens=conversation_config.max_prompt_size,
|
max_tokens=chat_model.max_prompt_size,
|
||||||
user=user,
|
user=user,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
|
|
@ -62,7 +62,7 @@ async def all_agents(
|
||||||
"color": agent.style_color,
|
"color": agent.style_color,
|
||||||
"icon": agent.style_icon,
|
"icon": agent.style_icon,
|
||||||
"privacy_level": agent.privacy_level,
|
"privacy_level": agent.privacy_level,
|
||||||
"chat_model": agent.chat_model.chat_model,
|
"chat_model": agent.chat_model.name,
|
||||||
"files": file_names,
|
"files": file_names,
|
||||||
"input_tools": agent.input_tools,
|
"input_tools": agent.input_tools,
|
||||||
"output_modes": agent.output_modes,
|
"output_modes": agent.output_modes,
|
||||||
|
@ -150,7 +150,7 @@ async def get_agent(
|
||||||
"color": agent.style_color,
|
"color": agent.style_color,
|
||||||
"icon": agent.style_icon,
|
"icon": agent.style_icon,
|
||||||
"privacy_level": agent.privacy_level,
|
"privacy_level": agent.privacy_level,
|
||||||
"chat_model": agent.chat_model.chat_model,
|
"chat_model": agent.chat_model.name,
|
||||||
"files": file_names,
|
"files": file_names,
|
||||||
"input_tools": agent.input_tools,
|
"input_tools": agent.input_tools,
|
||||||
"output_modes": agent.output_modes,
|
"output_modes": agent.output_modes,
|
||||||
|
@ -225,7 +225,7 @@ async def create_agent(
|
||||||
"color": agent.style_color,
|
"color": agent.style_color,
|
||||||
"icon": agent.style_icon,
|
"icon": agent.style_icon,
|
||||||
"privacy_level": agent.privacy_level,
|
"privacy_level": agent.privacy_level,
|
||||||
"chat_model": agent.chat_model.chat_model,
|
"chat_model": agent.chat_model.name,
|
||||||
"files": body.files,
|
"files": body.files,
|
||||||
"input_tools": agent.input_tools,
|
"input_tools": agent.input_tools,
|
||||||
"output_modes": agent.output_modes,
|
"output_modes": agent.output_modes,
|
||||||
|
@ -286,7 +286,7 @@ async def update_agent(
|
||||||
"color": agent.style_color,
|
"color": agent.style_color,
|
||||||
"icon": agent.style_icon,
|
"icon": agent.style_icon,
|
||||||
"privacy_level": agent.privacy_level,
|
"privacy_level": agent.privacy_level,
|
||||||
"chat_model": agent.chat_model.chat_model,
|
"chat_model": agent.chat_model.name,
|
||||||
"files": body.files,
|
"files": body.files,
|
||||||
"input_tools": agent.input_tools,
|
"input_tools": agent.input_tools,
|
||||||
"output_modes": agent.output_modes,
|
"output_modes": agent.output_modes,
|
||||||
|
|
|
@ -58,7 +58,7 @@ from khoj.routers.helpers import (
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
read_chat_stream,
|
read_chat_stream,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_chat_model,
|
||||||
)
|
)
|
||||||
from khoj.routers.research import (
|
from khoj.routers.research import (
|
||||||
InformationCollectionIteration,
|
InformationCollectionIteration,
|
||||||
|
@ -205,7 +205,7 @@ def chat_history(
|
||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
validate_conversation_config(user)
|
validate_chat_model(user)
|
||||||
|
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
conversation = ConversationAdapters.get_conversation_by_user(
|
conversation = ConversationAdapters.get_conversation_by_user(
|
||||||
|
@ -898,10 +898,10 @@ async def chat(
|
||||||
custom_filters = []
|
custom_filters = []
|
||||||
if conversation_commands == [ConversationCommand.Help]:
|
if conversation_commands == [ConversationCommand.Help]:
|
||||||
if not q:
|
if not q:
|
||||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
||||||
if conversation_config == None:
|
if chat_model == None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||||
model_type = conversation_config.model_type
|
model_type = chat_model.model_type
|
||||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
async for result in send_llm_response(formatted_help, tracer.get("usage")):
|
async for result in send_llm_response(formatted_help, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
|
|
|
@ -24,7 +24,7 @@ def get_chat_model_options(
|
||||||
|
|
||||||
all_conversation_options = list()
|
all_conversation_options = list()
|
||||||
for conversation_option in conversation_options:
|
for conversation_option in conversation_options:
|
||||||
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
|
all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id})
|
||||||
|
|
||||||
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
@ -37,12 +37,12 @@ def get_user_chat_model(
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
chat_model = ConversationAdapters.get_conversation_config(user)
|
chat_model = ConversationAdapters.get_chat_model(user)
|
||||||
|
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_model = ConversationAdapters.get_default_conversation_config(user)
|
chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||||
|
|
||||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name}))
|
||||||
|
|
||||||
|
|
||||||
@api_model.post("/chat", status_code=200)
|
@api_model.post("/chat", status_code=200)
|
||||||
|
|
|
@ -56,7 +56,7 @@ from khoj.database.adapters import (
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
Conversation,
|
Conversation,
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
|
@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool:
|
||||||
return is_none_or_empty(query.strip())
|
return is_none_or_empty(query.strip())
|
||||||
|
|
||||||
|
|
||||||
def validate_conversation_config(user: KhojUser):
|
def validate_chat_model(user: KhojUser):
|
||||||
default_config = ConversationAdapters.get_default_conversation_config(user)
|
default_chat_model = ConversationAdapters.get_default_chat_model(user)
|
||||||
|
|
||||||
if default_config is None:
|
if default_chat_model is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
if default_config.model_type == "openai" and not default_config.ai_model_api:
|
if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
|
|
||||||
async def is_ready_to_chat(user: KhojUser):
|
async def is_ready_to_chat(user: KhojUser):
|
||||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
|
||||||
if user_conversation_config == None:
|
if user_chat_model == None:
|
||||||
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||||
|
|
||||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model_name = user_chat_model.name
|
||||||
max_tokens = user_conversation_config.max_prompt_size
|
max_tokens = user_chat_model.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
logger.info("Loading Offline Chat Model...")
|
logger.info("Loading Offline Chat Model...")
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_conversation_config
|
user_chat_model
|
||||||
and (
|
and (
|
||||||
user_conversation_config.model_type
|
user_chat_model.model_type
|
||||||
in [
|
in [
|
||||||
ChatModelOptions.ModelType.OPENAI,
|
ChatModel.ModelType.OPENAI,
|
||||||
ChatModelOptions.ModelType.ANTHROPIC,
|
ChatModel.ModelType.ANTHROPIC,
|
||||||
ChatModelOptions.ModelType.GOOGLE,
|
ChatModel.ModelType.GOOGLE,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
and user_conversation_config.ai_model_api
|
and user_chat_model.ai_model_api
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -942,120 +942,124 @@ async def send_message_to_model_wrapper(
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
|
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
|
||||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
conversation_config = vision_enabled_config
|
chat_model = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
if vision_available and query_images:
|
if vision_available and query_images:
|
||||||
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
|
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
|
||||||
|
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
conversation_config.subscribed_max_prompt_size
|
chat_model.subscribed_max_prompt_size
|
||||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
if subscribed and chat_model.subscribed_max_prompt_size
|
||||||
else conversation_config.max_prompt_size
|
else chat_model.max_prompt_size
|
||||||
)
|
)
|
||||||
tokenizer = conversation_config.tokenizer
|
tokenizer = chat_model.tokenizer
|
||||||
model_type = conversation_config.model_type
|
model_type = chat_model.model_type
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
|
|
||||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
if model_type == ChatModel.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
elif model_type == ChatModel.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.ai_model_api
|
openai_chat_config = chat_model.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
api_base_url = openai_chat_config.api_base_url
|
api_base_url = openai_chat_config.api_base_url
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_message_to_model(
|
return send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
return anthropic_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model_name,
|
||||||
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync(
|
||||||
query_files: str = "",
|
query_files: str = "",
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
|
||||||
|
|
||||||
if conversation_config is None:
|
if chat_model is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = chat_model.max_prompt_size
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
return send_message_to_model_offline(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
loaded_model=loaded_model,
|
loaded_model=loaded_model,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
return anthropic_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
model_name=chat_model,
|
model_name=chat_model_name,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=conversation_config.model_type,
|
model_type=chat_model.model_type,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
@ -1229,15 +1233,15 @@ def generate_chat_response(
|
||||||
online_results = {}
|
online_results = {}
|
||||||
code_results = {}
|
code_results = {}
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
conversation_config = vision_enabled_config
|
chat_model = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if chat_model.model_type == "offline":
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
user_query=query_to_run,
|
user_query=query_to_run,
|
||||||
|
@ -1247,9 +1251,9 @@ def generate_chat_response(
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
model=conversation_config.chat_model,
|
model=chat_model.name,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=conversation_config.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
@ -1259,10 +1263,10 @@ def generate_chat_response(
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.ai_model_api
|
openai_chat_config = chat_model.ai_model_api
|
||||||
api_key = openai_chat_config.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = conversation_config.chat_model
|
chat_model_name = chat_model.name
|
||||||
chat_response = converse_openai(
|
chat_response = converse_openai(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
|
@ -1270,13 +1274,13 @@ def generate_chat_response(
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
code_results=code_results,
|
code_results=code_results,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
model=chat_model,
|
model=chat_model_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base_url=openai_chat_config.api_base_url,
|
api_base_url=openai_chat_config.api_base_url,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=conversation_config.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
@ -1288,8 +1292,8 @@ def generate_chat_response(
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
chat_response = converse_anthropic(
|
chat_response = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
|
@ -1297,12 +1301,12 @@ def generate_chat_response(
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
code_results=code_results,
|
code_results=code_results,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=conversation_config.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
@ -1313,20 +1317,20 @@ def generate_chat_response(
|
||||||
program_execution_context=program_execution_context,
|
program_execution_context=program_execution_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.ai_model_api.api_key
|
api_key = chat_model.ai_model_api.api_key
|
||||||
chat_response = converse_gemini(
|
chat_response = converse_gemini(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
online_results,
|
online_results,
|
||||||
code_results,
|
code_results,
|
||||||
meta_log,
|
meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=chat_model.name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_commands=conversation_commands,
|
conversation_commands=conversation_commands,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=chat_model.max_prompt_size,
|
||||||
tokenizer_name=conversation_config.tokenizer,
|
tokenizer_name=chat_model.tokenizer,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
@ -1339,7 +1343,7 @@ def generate_chat_response(
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": chat_model.name})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e, exc_info=True)
|
logger.error(e, exc_info=True)
|
||||||
|
@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
current_notion_config = get_user_notion_config(user)
|
current_notion_config = get_user_notion_config(user)
|
||||||
notion_token = current_notion_config.token if current_notion_config else ""
|
notion_token = current_notion_config.token if current_notion_config else ""
|
||||||
|
|
||||||
selected_chat_model_config = ConversationAdapters.get_conversation_config(
|
selected_chat_model_config = ConversationAdapters.get_chat_model(
|
||||||
user
|
user
|
||||||
) or ConversationAdapters.get_default_conversation_config(user)
|
) or ConversationAdapters.get_default_chat_model(user)
|
||||||
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
||||||
chat_model_options = list()
|
chat_model_options = list()
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
chat_model_options.append({"name": chat_model.name, "id": chat_model.id})
|
||||||
|
|
||||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||||
|
|
|
@ -7,7 +7,7 @@ import openai
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
AiModelApi,
|
AiModelApi,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
|
@ -63,7 +63,7 @@ def initialization(interactive: bool = True):
|
||||||
|
|
||||||
# Set up OpenAI's online chat models
|
# Set up OpenAI's online chat models
|
||||||
openai_configured, openai_provider = _setup_chat_model_provider(
|
openai_configured, openai_provider = _setup_chat_model_provider(
|
||||||
ChatModelOptions.ModelType.OPENAI,
|
ChatModel.ModelType.OPENAI,
|
||||||
default_chat_models,
|
default_chat_models,
|
||||||
default_api_key=openai_api_key,
|
default_api_key=openai_api_key,
|
||||||
api_base_url=openai_api_base,
|
api_base_url=openai_api_base,
|
||||||
|
@ -105,7 +105,7 @@ def initialization(interactive: bool = True):
|
||||||
|
|
||||||
# Set up Google's Gemini online chat models
|
# Set up Google's Gemini online chat models
|
||||||
_setup_chat_model_provider(
|
_setup_chat_model_provider(
|
||||||
ChatModelOptions.ModelType.GOOGLE,
|
ChatModel.ModelType.GOOGLE,
|
||||||
default_gemini_chat_models,
|
default_gemini_chat_models,
|
||||||
default_api_key=os.getenv("GEMINI_API_KEY"),
|
default_api_key=os.getenv("GEMINI_API_KEY"),
|
||||||
vision_enabled=True,
|
vision_enabled=True,
|
||||||
|
@ -116,7 +116,7 @@ def initialization(interactive: bool = True):
|
||||||
|
|
||||||
# Set up Anthropic's online chat models
|
# Set up Anthropic's online chat models
|
||||||
_setup_chat_model_provider(
|
_setup_chat_model_provider(
|
||||||
ChatModelOptions.ModelType.ANTHROPIC,
|
ChatModel.ModelType.ANTHROPIC,
|
||||||
default_anthropic_chat_models,
|
default_anthropic_chat_models,
|
||||||
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||||
vision_enabled=True,
|
vision_enabled=True,
|
||||||
|
@ -126,7 +126,7 @@ def initialization(interactive: bool = True):
|
||||||
|
|
||||||
# Set up offline chat models
|
# Set up offline chat models
|
||||||
_setup_chat_model_provider(
|
_setup_chat_model_provider(
|
||||||
ChatModelOptions.ModelType.OFFLINE,
|
ChatModel.ModelType.OFFLINE,
|
||||||
default_offline_chat_models,
|
default_offline_chat_models,
|
||||||
default_api_key=None,
|
default_api_key=None,
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
|
@ -135,9 +135,9 @@ def initialization(interactive: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Explicitly set default chat model
|
# Explicitly set default chat model
|
||||||
chat_models_configured = ChatModelOptions.objects.count()
|
chat_models_configured = ChatModel.objects.count()
|
||||||
if chat_models_configured > 0:
|
if chat_models_configured > 0:
|
||||||
default_chat_model_name = ChatModelOptions.objects.first().chat_model
|
default_chat_model_name = ChatModel.objects.first().name
|
||||||
# If there are multiple chat models, ask the user to choose the default chat model
|
# If there are multiple chat models, ask the user to choose the default chat model
|
||||||
if chat_models_configured > 1 and interactive:
|
if chat_models_configured > 1 and interactive:
|
||||||
user_chat_model_name = input(
|
user_chat_model_name = input(
|
||||||
|
@ -147,7 +147,7 @@ def initialization(interactive: bool = True):
|
||||||
user_chat_model_name = None
|
user_chat_model_name = None
|
||||||
|
|
||||||
# If the user's choice is valid, set it as the default chat model
|
# If the user's choice is valid, set it as the default chat model
|
||||||
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists():
|
||||||
default_chat_model_name = user_chat_model_name
|
default_chat_model_name = user_chat_model_name
|
||||||
|
|
||||||
logger.info("🗣️ Chat model configuration complete")
|
logger.info("🗣️ Chat model configuration complete")
|
||||||
|
@ -171,7 +171,7 @@ def initialization(interactive: bool = True):
|
||||||
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
|
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
|
||||||
|
|
||||||
def _setup_chat_model_provider(
|
def _setup_chat_model_provider(
|
||||||
model_type: ChatModelOptions.ModelType,
|
model_type: ChatModel.ModelType,
|
||||||
default_chat_models: list,
|
default_chat_models: list,
|
||||||
default_api_key: str,
|
default_api_key: str,
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
|
@ -226,7 +226,7 @@ def initialization(interactive: bool = True):
|
||||||
"ai_model_api": ai_model_api,
|
"ai_model_api": ai_model_api,
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatModelOptions.objects.create(**chat_model_options)
|
ChatModel.objects.create(**chat_model_options)
|
||||||
|
|
||||||
logger.info(f"🗣️ {provider_name} chat model configuration complete")
|
logger.info(f"🗣️ {provider_name} chat model configuration complete")
|
||||||
return True, ai_model_api
|
return True, ai_model_api
|
||||||
|
@ -250,16 +250,16 @@ def initialization(interactive: bool = True):
|
||||||
available_models = [model.id for model in openai_client.models.list()]
|
available_models = [model.id for model in openai_client.models.list()]
|
||||||
|
|
||||||
# Get existing chat model options for this config
|
# Get existing chat model options for this config
|
||||||
existing_models = ChatModelOptions.objects.filter(
|
existing_models = ChatModel.objects.filter(
|
||||||
ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI
|
ai_model_api=config, model_type=ChatModel.ModelType.OPENAI
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add new models
|
# Add new models
|
||||||
for model in available_models:
|
for model in available_models:
|
||||||
if not existing_models.filter(chat_model=model).exists():
|
if not existing_models.filter(name=model).exists():
|
||||||
ChatModelOptions.objects.create(
|
ChatModel.objects.create(
|
||||||
chat_model=model,
|
name=model,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModel.ModelType.OPENAI,
|
||||||
max_prompt_size=model_to_prompt_size.get(model),
|
max_prompt_size=model_to_prompt_size.get(model),
|
||||||
vision_enabled=model in default_openai_chat_models,
|
vision_enabled=model in default_openai_chat_models,
|
||||||
tokenizer=model_to_tokenizer.get(model),
|
tokenizer=model_to_tokenizer.get(model),
|
||||||
|
@ -284,7 +284,7 @@ def initialization(interactive: bool = True):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
|
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
|
||||||
|
|
||||||
chat_config = ConversationAdapters.get_default_conversation_config()
|
chat_config = ConversationAdapters.get_default_chat_model()
|
||||||
if admin_user is None and chat_config is None:
|
if admin_user is None and chat_config is None:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -13,7 +13,7 @@ from khoj.configure import (
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
|
@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
||||||
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
|
||||||
from tests.helpers import (
|
from tests.helpers import (
|
||||||
AiModelApiFactory,
|
AiModelApiFactory,
|
||||||
ChatModelOptionsFactory,
|
ChatModelFactory,
|
||||||
ProcessLockFactory,
|
ProcessLockFactory,
|
||||||
SubscriptionFactory,
|
SubscriptionFactory,
|
||||||
UserConversationProcessorConfigFactory,
|
UserConversationProcessorConfigFactory,
|
||||||
|
@ -184,14 +184,14 @@ def api_user4(default_user4):
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_openai_chat_model_option():
|
def default_openai_chat_model_option():
|
||||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||||
return chat_model
|
return chat_model
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def offline_agent():
|
def offline_agent():
|
||||||
chat_model = ChatModelOptionsFactory()
|
chat_model = ChatModelFactory()
|
||||||
return Agent.objects.create(
|
return Agent.objects.create(
|
||||||
name="Accountant",
|
name="Accountant",
|
||||||
chat_model=chat_model,
|
chat_model=chat_model,
|
||||||
|
@ -202,7 +202,7 @@ def offline_agent():
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def openai_agent():
|
def openai_agent():
|
||||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||||
return Agent.objects.create(
|
return Agent.objects.create(
|
||||||
name="Accountant",
|
name="Accountant",
|
||||||
chat_model=chat_model,
|
chat_model=chat_model,
|
||||||
|
@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
chat_provider = get_chat_provider()
|
chat_provider = get_chat_provider()
|
||||||
online_chat_model: ChatModelOptionsFactory = None
|
online_chat_model: ChatModelFactory = None
|
||||||
if chat_provider == ChatModelOptions.ModelType.OPENAI:
|
if chat_provider == ChatModel.ModelType.OPENAI:
|
||||||
online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
|
||||||
elif chat_provider == ChatModelOptions.ModelType.GOOGLE:
|
elif chat_provider == ChatModel.ModelType.GOOGLE:
|
||||||
online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google")
|
online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
|
||||||
elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC:
|
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
|
||||||
online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic")
|
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
|
||||||
if online_chat_model:
|
if online_chat_model:
|
||||||
online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
|
online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
|
||||||
UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
|
UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
|
||||||
|
@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
configure_content(default_user2, all_files)
|
configure_content(default_user2, all_files)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
ChatModelOptionsFactory(
|
ChatModelFactory(
|
||||||
chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
|
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
model_type="offline",
|
model_type="offline",
|
||||||
|
|
|
@ -6,7 +6,7 @@ from django.utils.timezone import make_aware
|
||||||
|
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
AiModelApi,
|
AiModelApi,
|
||||||
ChatModelOptions,
|
ChatModel,
|
||||||
Conversation,
|
Conversation,
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
|
@ -18,27 +18,27 @@ from khoj.database.models import (
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
from khoj.processor.conversation.utils import message_to_log
|
||||||
|
|
||||||
|
|
||||||
def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
|
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
|
||||||
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
|
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
|
||||||
if provider and provider in ChatModelOptions.ModelType:
|
if provider and provider in ChatModel.ModelType:
|
||||||
return ChatModelOptions.ModelType(provider)
|
return ChatModel.ModelType(provider)
|
||||||
elif os.getenv("OPENAI_API_KEY"):
|
elif os.getenv("OPENAI_API_KEY"):
|
||||||
return ChatModelOptions.ModelType.OPENAI
|
return ChatModel.ModelType.OPENAI
|
||||||
elif os.getenv("GEMINI_API_KEY"):
|
elif os.getenv("GEMINI_API_KEY"):
|
||||||
return ChatModelOptions.ModelType.GOOGLE
|
return ChatModel.ModelType.GOOGLE
|
||||||
elif os.getenv("ANTHROPIC_API_KEY"):
|
elif os.getenv("ANTHROPIC_API_KEY"):
|
||||||
return ChatModelOptions.ModelType.ANTHROPIC
|
return ChatModel.ModelType.ANTHROPIC
|
||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
|
def get_chat_api_key(provider: ChatModel.ModelType = None):
|
||||||
provider = provider or get_chat_provider()
|
provider = provider or get_chat_provider()
|
||||||
if provider == ChatModelOptions.ModelType.OPENAI:
|
if provider == ChatModel.ModelType.OPENAI:
|
||||||
return os.getenv("OPENAI_API_KEY")
|
return os.getenv("OPENAI_API_KEY")
|
||||||
elif provider == ChatModelOptions.ModelType.GOOGLE:
|
elif provider == ChatModel.ModelType.GOOGLE:
|
||||||
return os.getenv("GEMINI_API_KEY")
|
return os.getenv("GEMINI_API_KEY")
|
||||||
elif provider == ChatModelOptions.ModelType.ANTHROPIC:
|
elif provider == ChatModel.ModelType.ANTHROPIC:
|
||||||
return os.getenv("ANTHROPIC_API_KEY")
|
return os.getenv("ANTHROPIC_API_KEY")
|
||||||
else:
|
else:
|
||||||
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory):
|
||||||
api_key = get_chat_api_key()
|
api_key = get_chat_api_key()
|
||||||
|
|
||||||
|
|
||||||
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
class ChatModelFactory(factory.django.DjangoModelFactory):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = ChatModelOptions
|
model = ChatModel
|
||||||
|
|
||||||
max_prompt_size = 20000
|
max_prompt_size = 20000
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
||||||
model_type = get_chat_provider()
|
model_type = get_chat_provider()
|
||||||
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||||
model = UserConversationConfig
|
model = UserConversationConfig
|
||||||
|
|
||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
setting = factory.SubFactory(ChatModelOptionsFactory)
|
setting = factory.SubFactory(ChatModelFactory)
|
||||||
|
|
||||||
|
|
||||||
class ConversationFactory(factory.django.DjangoModelFactory):
|
class ConversationFactory(factory.django.DjangoModelFactory):
|
||||||
|
|
|
@ -5,14 +5,14 @@ import pytest
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters
|
from khoj.database.adapters import AgentAdapters
|
||||||
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
|
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||||
from khoj.routers.api import execute_search
|
from khoj.routers.api import execute_search
|
||||||
from khoj.utils.helpers import get_absolute_path
|
from khoj.utils.helpers import get_absolute_path
|
||||||
from tests.helpers import ChatModelOptionsFactory
|
from tests.helpers import ChatModelFactory
|
||||||
|
|
||||||
|
|
||||||
def test_create_default_agent(default_user: KhojUser):
|
def test_create_default_agent(default_user: KhojUser):
|
||||||
ChatModelOptionsFactory()
|
ChatModelFactory()
|
||||||
|
|
||||||
agent = AgentAdapters.create_default_agent(default_user)
|
agent = AgentAdapters.create_default_agent(default_user)
|
||||||
assert agent is not None
|
assert agent is not None
|
||||||
|
@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser):
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
|
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel):
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
default_user,
|
default_user,
|
||||||
"Test Agent",
|
"Test Agent",
|
||||||
|
@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
|
||||||
Agent.PrivacyLevel.PRIVATE,
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_create_or_update_agent_with_knowledge_base(
|
async def test_create_or_update_agent_with_knowledge_base(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base(
|
||||||
Agent.PrivacyLevel.PRIVATE,
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base(
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_create_or_update_agent_with_knowledge_base_and_search(
|
async def test_create_or_update_agent_with_knowledge_base_and_search(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
|
||||||
Agent.PrivacyLevel.PRIVATE,
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_agent_with_knowledge_base_and_search_not_creator(
|
async def test_agent_with_knowledge_base_and_search_not_creator(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
|
||||||
Agent.PrivacyLevel.PUBLIC,
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
||||||
Agent.PrivacyLevel.PRIVATE,
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
|
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
|
||||||
Agent.PrivacyLevel.PRIVATE,
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_multiple_agents_with_knowledge_base_and_users(
|
async def test_multiple_agents_with_knowledge_base_and_users(
|
||||||
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
|
||||||
):
|
):
|
||||||
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
new_agent = await AgentAdapters.aupdate_agent(
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
|
||||||
Agent.PrivacyLevel.PUBLIC,
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename],
|
[full_filename],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
|
||||||
Agent.PrivacyLevel.PUBLIC,
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
"icon",
|
"icon",
|
||||||
"color",
|
"color",
|
||||||
default_openai_chat_model_option.chat_model,
|
default_openai_chat_model_option.name,
|
||||||
[full_filename2],
|
[full_filename2],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
|
|
@ -2,12 +2,12 @@ from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from khoj.database.models import ChatModelOptions
|
from khoj.database.models import ChatModel
|
||||||
from khoj.routers.helpers import aget_data_sources_and_output_format
|
from khoj.routers.helpers import aget_data_sources_and_output_format
|
||||||
from khoj.utils.helpers import ConversationCommand
|
from khoj.utils.helpers import ConversationCommand
|
||||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
||||||
|
|
||||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
|
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
SKIP_TESTS,
|
SKIP_TESTS,
|
||||||
reason="Disable in CI to avoid long test runs.",
|
reason="Disable in CI to avoid long test runs.",
|
||||||
|
|
|
@ -4,12 +4,12 @@ import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
|
||||||
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
|
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
from khoj.processor.conversation.utils import message_to_log
|
||||||
from tests.helpers import ConversationFactory, get_chat_provider
|
from tests.helpers import ConversationFactory, get_chat_provider
|
||||||
|
|
||||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
|
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
SKIP_TESTS,
|
SKIP_TESTS,
|
||||||
reason="Disable in CI to avoid long test runs.",
|
reason="Disable in CI to avoid long test runs.",
|
||||||
|
|
Loading…
Reference in a new issue