Merge branch 'master' of github.com:khoj-ai/khoj into features/new-sign-in-page

Author: sabaimran
Date: 2024-12-12 15:43:06 -08:00
Commit: dfc150c442
30 changed files with 412 additions and 340 deletions

View file

@@ -113,7 +113,8 @@ jobs:
 khoj --anonymous-mode --non-interactive &
 # Start code sandbox
-npm run dev --prefix terrarium &
+npm install -g pm2
+npm run ci --prefix terrarium
 # Wait for server to be ready
 timeout=120
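Aside: the `timeout=120` context above is the workflow's readiness wait for the Khoj server. For readers adapting this outside CI, a minimal readiness poll in Python might look like the sketch below; the port 42110 comes from the docs later in this commit, everything else is an assumption:

```python
import time
import urllib.request


def wait_for_server(url: str = "http://localhost:42110", timeout: int = 120) -> bool:
    """Poll a server URL until it responds or the timeout (in seconds) elapses."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # Any successful HTTP response means the server is accepting requests
            with urllib.request.urlopen(url, timeout=5):
                return True
        except OSError:
            time.sleep(2)  # Server not up yet; retry shortly
    return False


if __name__ == "__main__":
    assert wait_for_server(), "Server did not come up within the timeout"
```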

View file

@@ -25,7 +25,7 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
    - Name: `proxy-name`
    - Api Key: `any string`
    - Api Base Url: **URL of your Openai Proxy API**
-4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
+4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
    - Name: `llama3.1` (replace with the name of your local model)
    - Model Type: `Openai`
    - Openai Config: `<the proxy config you created in step 3>`
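Since the proxy exposes an OpenAI-compatible API, you can sanity-check the settings in this hunk before wiring them into Khoj. A minimal sketch using the official `openai` Python client; the base URL `http://localhost:4000` is an assumption (LiteLLM's default port), so substitute your proxy's URL:

```python
from openai import OpenAI

# Mirror the Khoj admin panel settings: any string works as the API key,
# and the base URL points at your OpenAI-compatible proxy.
client = OpenAI(api_key="any string", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="llama3.1",  # must match the chat model name configured in Khoj
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)
```

If this round-trips, the same Name, Api Key, and Api Base Url values should work in the admin panel steps above.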

View file

@@ -18,7 +18,7 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
    - Name: `proxy-name`
    - Api Key: `any string`
    - Api Base Url: `http://localhost:1234/v1/` (default for LMStudio)
-4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
+4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
    - Name: `llama3.1` (replace with the name of your local model)
    - Model Type: `Openai`
    - Openai Config: `<the proxy config you created in step 3>`

View file

@@ -64,7 +64,7 @@ Restart your Khoj server after first run or update to the settings below to ensu
    - Name: `ollama`
    - Api Key: `any string`
    - Api Base Url: `http://localhost:11434/v1/` (default for Ollama)
-4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
+4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
    - Name: `llama3.1` (replace with the name of your local model)
    - Model Type: `Openai`
    - Openai Config: `<the ollama config you created in step 3>`
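To confirm Ollama's OpenAI-compatible endpoint is reachable at the base URL above, a quick sketch; it assumes Ollama is running and a model has already been pulled (e.g. `ollama pull llama3.1`):

```python
import requests

# List the models Ollama serves on its OpenAI-compatible endpoint
resp = requests.get("http://localhost:11434/v1/models")
resp.raise_for_status()
print([model["id"] for model in resp.json()["data"]])
```

The model name you enter in the Khoj admin panel should appear in this list.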

View file

@@ -25,7 +25,7 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
    - Name: `any name`
    - Api Key: `any string`
    - Api Base Url: **URL of your Openai Proxy API**
-3. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
+3. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
    - Name: `llama3` (replace with the name of your local model)
    - Model Type: `Openai`
    - Openai Config: `<the proxy config you created in step 2>`

View file

@@ -17,7 +17,7 @@ You have a couple of image generation options.
 We support most state of the art image generation models, including Ideogram, Flux, and Stable Diffusion. These will run using [Replicate](https://replicate.com). Here's how to set them up:
 1. Get a Replicate API key [here](https://replicate.com/account/api-tokens).
-1. Create a new [Text to Image Model](https://app.khoj.dev/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models).
+1. Create a new [Text to Image Model](http://localhost:42110/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models).
 ### OpenAI

View file

@@ -307,7 +307,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
    - Give the configuration a friendly name like `OpenAI`
    - (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).<br />
      ![example configuration for ai model api](/img/example_openai_processor_config.png)
-2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
+2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
    - Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`.
    - Make sure to set the `model-type` field to `OpenAI`.
    - If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities.
@@ -318,7 +318,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
 1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
    - Add your [Anthropic API key](https://console.anthropic.com/account/keys)
    - Give the configuration a friendly name like `Anthropic`. Do not configure the API base url.
-2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
+2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
    - Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`.
    - Set the `model-type` field to `Anthropic`.
    - Set the `ai model api` field to the Anthropic AI Model API you created in step 1.
@@ -327,7 +327,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
 1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
    - Add your [Gemini API key](https://aistudio.google.com/app/apikey)
    - Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
-2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
+2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
    - Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
    - Set the `model-type` field to `Gemini`.
    - Set the `ai model api` field to the Gemini AI Model API you created in step 1.
@@ -343,7 +343,7 @@ Offline chat stays completely private and can work without internet using any op
 :::
 1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
-2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodeloptions/add/) on the admin panel
+2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
 3. Set the `chat-model` field to the name of your preferred chat model
    - Make sure the `model-type` is set to `Offline`
 4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
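These admin-panel steps map directly onto the renamed `ChatModel` Django model introduced by this commit (see the `database/models` diff further down). Equivalently, a chat model can be created from a Django shell; a sketch, assuming an `AiModelApi` row named `OpenAI` already exists from step 1:

```python
from khoj.database.models import AiModelApi, ChatModel

api = AiModelApi.objects.get(name="OpenAI")  # the AI Model API created in step 1
ChatModel.objects.create(
    name="gpt-4o",                            # was the `chat_model` field before this commit
    model_type=ChatModel.ModelType.OPENAI,
    vision_enabled=True,
    ai_model_api=api,
)
```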

View file

@@ -36,7 +36,7 @@ from torch import Tensor
 from khoj.database.models import (
     Agent,
     AiModelApi,
-    ChatModelOptions,
+    ChatModel,
     ClientApplication,
     Conversation,
     Entry,
@@ -736,8 +736,8 @@ class AgentAdapters:
     @staticmethod
     def create_default_agent(user: KhojUser):
-        default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
-        if default_conversation_config is None:
+        default_chat_model = ConversationAdapters.get_default_chat_model(user)
+        if default_chat_model is None:
             logger.info("No default conversation config found, skipping default agent creation")
             return None
         default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
@@ -746,7 +746,7 @@ class AgentAdapters:
         if agent:
             agent.personality = default_personality
-            agent.chat_model = default_conversation_config
+            agent.chat_model = default_chat_model
             agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
             agent.name = AgentAdapters.DEFAULT_AGENT_NAME
             agent.privacy_level = Agent.PrivacyLevel.PUBLIC
@@ -760,7 +760,7 @@ class AgentAdapters:
                 name=AgentAdapters.DEFAULT_AGENT_NAME,
                 privacy_level=Agent.PrivacyLevel.PUBLIC,
                 managed_by_admin=True,
-                chat_model=default_conversation_config,
+                chat_model=default_chat_model,
                 personality=default_personality,
                 slug=AgentAdapters.DEFAULT_AGENT_SLUG,
             )
@@ -787,7 +787,7 @@ class AgentAdapters:
         output_modes: List[str],
         slug: Optional[str] = None,
     ):
-        chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
+        chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()

         # Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
         agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
@@ -972,29 +972,29 @@ class ConversationAdapters:
     @staticmethod
     @require_valid_user
-    def has_any_conversation_config(user: KhojUser):
-        return ChatModelOptions.objects.filter(user=user).exists()
+    def has_any_chat_model(user: KhojUser):
+        return ChatModel.objects.filter(user=user).exists()

     @staticmethod
-    def get_all_conversation_configs():
-        return ChatModelOptions.objects.all()
+    def get_all_chat_models():
+        return ChatModel.objects.all()

     @staticmethod
-    async def aget_all_conversation_configs():
-        return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all())
+    async def aget_all_chat_models():
+        return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())

     @staticmethod
     def get_vision_enabled_config():
-        conversation_configurations = ConversationAdapters.get_all_conversation_configs()
-        for config in conversation_configurations:
+        chat_models = ConversationAdapters.get_all_chat_models()
+        for config in chat_models:
             if config.vision_enabled:
                 return config
         return None

     @staticmethod
     async def aget_vision_enabled_config():
-        conversation_configurations = await ConversationAdapters.aget_all_conversation_configs()
-        for config in conversation_configurations:
+        chat_models = await ConversationAdapters.aget_all_chat_models()
+        for config in chat_models:
             if config.vision_enabled:
                 return config
         return None
@@ -1010,7 +1010,7 @@ class ConversationAdapters:
     @staticmethod
     @arequire_valid_user
     async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
-        config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
+        config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst()
         if not config:
             return None
         new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
@@ -1026,24 +1026,24 @@ class ConversationAdapters:
         return new_config

     @staticmethod
-    def get_conversation_config(user: KhojUser):
+    def get_chat_model(user: KhojUser):
         subscribed = is_user_subscribed(user)
         if not subscribed:
-            return ConversationAdapters.get_default_conversation_config(user)
+            return ConversationAdapters.get_default_chat_model(user)
         config = UserConversationConfig.objects.filter(user=user).first()
         if config:
             return config.setting
-        return ConversationAdapters.get_advanced_conversation_config(user)
+        return ConversationAdapters.get_advanced_chat_model(user)

     @staticmethod
-    async def aget_conversation_config(user: KhojUser):
+    async def aget_chat_model(user: KhojUser):
         subscribed = await ais_user_subscribed(user)
         if not subscribed:
-            return await ConversationAdapters.aget_default_conversation_config(user)
+            return await ConversationAdapters.aget_default_chat_model(user)
         config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
         if config:
             return config.setting
-        return ConversationAdapters.aget_advanced_conversation_config(user)
+        return ConversationAdapters.aget_advanced_chat_model(user)

     @staticmethod
     async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@@ -1064,7 +1064,7 @@ class ConversationAdapters:
         return VoiceModelOption.objects.first()

     @staticmethod
-    def get_default_conversation_config(user: KhojUser = None):
+    def get_default_chat_model(user: KhojUser = None):
         """Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
         # Get the server chat settings
         server_chat_settings = ServerChatSettings.objects.first()
@@ -1084,10 +1084,10 @@ class ConversationAdapters:
             return user_chat_settings.setting

         # Get the first chat model if even the user chat settings are not set
-        return ChatModelOptions.objects.filter().first()
+        return ChatModel.objects.filter().first()

     @staticmethod
-    async def aget_default_conversation_config(user: KhojUser = None):
+    async def aget_default_chat_model(user: KhojUser = None):
         """Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
         # Get the server chat settings
         server_chat_settings: ServerChatSettings = (
@@ -1117,17 +1117,17 @@ class ConversationAdapters:
             return user_chat_settings.setting

         # Get the first chat model if even the user chat settings are not set
-        return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
+        return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst()

     @staticmethod
-    def get_advanced_conversation_config(user: KhojUser):
+    def get_advanced_chat_model(user: KhojUser):
         server_chat_settings = ServerChatSettings.objects.first()
         if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
             return server_chat_settings.chat_advanced
-        return ConversationAdapters.get_default_conversation_config(user)
+        return ConversationAdapters.get_default_chat_model(user)

     @staticmethod
-    async def aget_advanced_conversation_config(user: KhojUser = None):
+    async def aget_advanced_chat_model(user: KhojUser = None):
         server_chat_settings: ServerChatSettings = (
             await ServerChatSettings.objects.filter()
             .prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
@@ -1135,7 +1135,7 @@ class ConversationAdapters:
         )
         if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
             return server_chat_settings.chat_advanced
-        return await ConversationAdapters.aget_default_conversation_config(user)
+        return await ConversationAdapters.aget_default_chat_model(user)

     @staticmethod
     async def aget_server_webscraper():
@@ -1247,16 +1247,16 @@ class ConversationAdapters:
     @staticmethod
     def get_conversation_processor_options():
-        return ChatModelOptions.objects.all()
+        return ChatModel.objects.all()

     @staticmethod
-    def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
+    def set_user_chat_model(user: KhojUser, chat_model: ChatModel):
         user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
-        user_conversation_config.setting = new_config
+        user_conversation_config.setting = chat_model
         user_conversation_config.save()

     @staticmethod
-    async def aget_user_conversation_config(user: KhojUser):
+    async def aget_user_chat_model(user: KhojUser):
         config = (
             await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
         )
@@ -1288,33 +1288,33 @@ class ConversationAdapters:
         return random.sample(all_questions, max_results)

     @staticmethod
-    def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
+    def get_valid_chat_model(user: KhojUser, conversation: Conversation):
         agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
         if agent and agent.chat_model:
-            conversation_config = conversation.agent.chat_model
+            chat_model = conversation.agent.chat_model
         else:
-            conversation_config = ConversationAdapters.get_conversation_config(user)
-        if conversation_config is None:
-            conversation_config = ConversationAdapters.get_default_conversation_config()
+            chat_model = ConversationAdapters.get_chat_model(user)
+        if chat_model is None:
+            chat_model = ConversationAdapters.get_default_chat_model()

-        if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+        if chat_model.model_type == ChatModel.ModelType.OFFLINE:
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                chat_model = conversation_config.chat_model
-                max_tokens = conversation_config.max_prompt_size
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+                chat_model_name = chat_model.name
+                max_tokens = chat_model.max_prompt_size
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)

-            return conversation_config
+            return chat_model

         if (
-            conversation_config.model_type
+            chat_model.model_type
             in [
-                ChatModelOptions.ModelType.ANTHROPIC,
-                ChatModelOptions.ModelType.OPENAI,
-                ChatModelOptions.ModelType.GOOGLE,
+                ChatModel.ModelType.ANTHROPIC,
+                ChatModel.ModelType.OPENAI,
+                ChatModel.ModelType.GOOGLE,
             ]
-        ) and conversation_config.ai_model_api:
-            return conversation_config
+        ) and chat_model.ai_model_api:
+            return chat_model
         else:
             raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
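The net effect of this adapter rename for calling code: `get_conversation_config`-style helpers become `get_chat_model`-style helpers, and the model's identifier moves from the `chat_model` field to `name`. A before/after sketch of a typical call site, derived from the diff above:

```python
# Before this commit
config = ConversationAdapters.get_conversation_config(user)
model_name = config.chat_model if config else None

# After this commit
chat_model = ConversationAdapters.get_chat_model(user)
model_name = chat_model.name if chat_model else None
```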

View file

@@ -16,7 +16,7 @@ from unfold import admin as unfold_admin
 from khoj.database.models import (
     Agent,
     AiModelApi,
-    ChatModelOptions,
+    ChatModel,
     ClientApplication,
     Conversation,
     Entry,
@@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
     list_filter = ("type",)

-@admin.register(ChatModelOptions)
-class ChatModelOptionsAdmin(unfold_admin.ModelAdmin):
+@admin.register(ChatModel)
+class ChatModelAdmin(unfold_admin.ModelAdmin):
     list_display = (
         "id",
-        "chat_model",
+        "name",
         "ai_model_api",
         "max_prompt_size",
     )
-    search_fields = ("id", "chat_model", "ai_model_api__name")
+    search_fields = ("id", "name", "ai_model_api__name")

 @admin.register(TextToImageModelConfig)
@@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
         "get_chat_model",
         "get_subscription_type",
     )
-    search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type")
+    search_fields = ("id", "user__email", "setting__name", "user__subscription__type")
     ordering = ("-updated_at",)

     def get_user_email(self, obj):
@@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
     get_user_email.admin_order_field = "user__email"  # type: ignore

     def get_chat_model(self, obj):
-        return obj.setting.chat_model if obj.setting else None
+        return obj.setting.name if obj.setting else None

     get_chat_model.short_description = "Chat Model"  # type: ignore
-    get_chat_model.admin_order_field = "setting__chat_model"  # type: ignore
+    get_chat_model.admin_order_field = "setting__name"  # type: ignore

     def get_subscription_type(self, obj):
         if hasattr(obj.user, "subscription"):

View file

@@ -0,0 +1,62 @@
+# Generated by Django 5.0.9 on 2024-12-09 04:21
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"),
+    ]
+
+    operations = [
+        migrations.RenameModel(
+            old_name="ChatModelOptions",
+            new_name="ChatModel",
+        ),
+        migrations.RenameField(
+            model_name="chatmodel",
+            old_name="chat_model",
+            new_name="name",
+        ),
+        migrations.AlterField(
+            model_name="agent",
+            name="chat_model",
+            field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"),
+        ),
+        migrations.AlterField(
+            model_name="serverchatsettings",
+            name="chat_advanced",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                related_name="chat_advanced",
+                to="database.chatmodel",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="serverchatsettings",
+            name="chat_default",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                related_name="chat_default",
+                to="database.chatmodel",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="userconversationconfig",
+            name="setting",
+            field=models.ForeignKey(
+                blank=True,
+                default=None,
+                null=True,
+                on_delete=django.db.models.deletion.CASCADE,
+                to="database.chatmodel",
+            ),
+        ),
+    ]

View file

@@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel):
         return self.name

-class ChatModelOptions(DbBaseModel):
+class ChatModel(DbBaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
@@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel):
     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
-    chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
+    name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
     vision_enabled = models.BooleanField(default=False)
     ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)

     def __str__(self):
-        return self.chat_model
+        return self.name

 class VoiceModelOption(DbBaseModel):
@@ -297,7 +297,7 @@ class Agent(DbBaseModel):
         models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
     )
     managed_by_admin = models.BooleanField(default=False)
-    chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
+    chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE)
     slug = models.CharField(max_length=200, unique=True)
     style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
     style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
@@ -438,10 +438,10 @@ class WebScraper(DbBaseModel):
 class ServerChatSettings(DbBaseModel):
     chat_default = models.ForeignKey(
-        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
+        ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
     )
     chat_advanced = models.ForeignKey(
-        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
+        ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
     )
     web_scraper = models.ForeignKey(
         WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
@@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel):
 class UserConversationConfig(DbBaseModel):
     user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
-    setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
+    setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)

 class UserVoiceModelConfig(DbBaseModel):

View file

@@ -60,7 +60,7 @@ import logging
 from packaging import version

-from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig
+from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
 from khoj.utils.yaml import load_config_from_file, save_config_to_file

 logger = logging.getLogger(__name__)
@@ -98,11 +98,11 @@ def migrate_server_pg(args):
             if "offline-chat" in raw_config["processor"]["conversation"]:
                 offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
-                ChatModelOptions.objects.create(
-                    chat_model=offline_chat.get("chat-model"),
+                ChatModel.objects.create(
+                    name=offline_chat.get("chat-model"),
                     tokenizer=processor_conversation.get("tokenizer"),
                     max_prompt_size=processor_conversation.get("max-prompt-size"),
-                    model_type=ChatModelOptions.ModelType.OFFLINE,
+                    model_type=ChatModel.ModelType.OFFLINE,
                 )

             if (
@@ -119,11 +119,11 @@ def migrate_server_pg(args):
                 openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")

-                ChatModelOptions.objects.create(
-                    chat_model=openai.get("chat-model"),
+                ChatModel.objects.create(
+                    name=openai.get("chat-model"),
                     tokenizer=processor_conversation.get("tokenizer"),
                     max_prompt_size=processor_conversation.get("max-prompt-size"),
-                    model_type=ChatModelOptions.ModelType.OPENAI,
+                    model_type=ChatModel.ModelType.OPENAI,
                     ai_model_api=openai_model_api,
                 )

View file

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, ChatModelOptions, KhojUser
+from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
@@ -85,7 +85,7 @@ def extract_questions_anthropic(
     prompt = construct_structured_message(
         message=prompt,
         images=query_images,
-        model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        model_type=ChatModel.ModelType.ANTHROPIC,
         vision_enabled=vision_enabled,
         attached_file_context=query_files,
     )
@@ -218,7 +218,7 @@ def converse_anthropic(
         tokenizer_name=tokenizer_name,
         query_images=query_images,
         vision_enabled=vision_available,
-        model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        model_type=ChatModel.ModelType.ANTHROPIC,
         query_files=query_files,
         generated_files=generated_files,
         generated_asset_results=generated_asset_results,

View file

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, ChatModelOptions, KhojUser
+from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.google.utils import (
     format_messages_for_gemini,
@@ -86,7 +86,7 @@ def extract_questions_gemini(
     prompt = construct_structured_message(
         message=prompt,
         images=query_images,
-        model_type=ChatModelOptions.ModelType.GOOGLE,
+        model_type=ChatModel.ModelType.GOOGLE,
         vision_enabled=vision_enabled,
         attached_file_context=query_files,
     )
@@ -229,7 +229,7 @@ def converse_gemini(
         tokenizer_name=tokenizer_name,
         query_images=query_images,
         vision_enabled=vision_available,
-        model_type=ChatModelOptions.ModelType.GOOGLE,
+        model_type=ChatModel.ModelType.GOOGLE,
         query_files=query_files,
         generated_files=generated_files,
         generated_asset_results=generated_asset_results,

View file

@@ -9,7 +9,7 @@ import pyjson5
 from langchain.schema import ChatMessage
 from llama_cpp import Llama

-from khoj.database.models import Agent, ChatModelOptions, KhojUser
+from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
@@ -96,7 +96,7 @@ def extract_questions_offline(
         model_name=model,
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
-        model_type=ChatModelOptions.ModelType.OFFLINE,
+        model_type=ChatModel.ModelType.OFFLINE,
         query_files=query_files,
     )
@@ -105,7 +105,7 @@ def extract_questions_offline(
     response = send_message_to_model_offline(
         messages,
         loaded_model=offline_chat_model,
-        model=model,
+        model_name=model,
         max_prompt_size=max_prompt_size,
         temperature=temperature,
         response_type="json_object",
@@ -154,7 +154,7 @@ def converse_offline(
     online_results={},
     code_results={},
     conversation_log={},
-    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+    model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
     loaded_model: Union[Any, None] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -174,8 +174,8 @@ def converse_offline(
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    tracer["chat_model"] = model
+    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
+    tracer["chat_model"] = model_name

     current_date = datetime.now()

     if agent and agent.personality:
@@ -228,18 +228,18 @@ def converse_offline(
         system_prompt,
         conversation_log,
         context_message=context_message,
-        model_name=model,
+        model_name=model_name,
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
-        model_type=ChatModelOptions.ModelType.OFFLINE,
+        model_type=ChatModel.ModelType.OFFLINE,
         query_files=query_files,
         generated_files=generated_files,
         generated_asset_results=generated_asset_results,
         program_execution_context=additional_context,
     )

-    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
+    logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
     t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
@@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
 def send_message_to_model_offline(
     messages: List[ChatMessage],
     loaded_model=None,
-    model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+    model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
     temperature: float = 0.2,
     streaming=False,
     stop=[],
@@ -282,7 +282,7 @@ def send_message_to_model_offline(
     tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
+    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
     seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
     response = offline_chat_model.create_chat_completion(
@@ -301,7 +301,7 @@ def send_message_to_model_offline(
         # Save conversation trace for non-streaming responses
         # Streamed responses need to be saved by the calling function
-        tracer["chat_model"] = model
+        tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
         if is_promptrace_enabled():
             commit_conversation_trace(messages, response_text, tracer)
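With the `model` → `model_name` parameter rename in this file, call sites now pass the model identifier under the new keyword. A usage sketch of the renamed `send_message_to_model_offline` signature; the message content and sizes are illustrative, and the keyword arguments shown are those visible in the diff above:

```python
from langchain.schema import ChatMessage

response = send_message_to_model_offline(
    messages=[ChatMessage(role="user", content="Summarize my notes on Django.")],
    model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",  # downloaded via download_model if no loaded_model is given
    max_prompt_size=4096,
    temperature=0.2,
    response_type="json_object",
)
```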

View file

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional
 import pyjson5
 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, ChatModelOptions, KhojUser
+from khoj.database.models import Agent, ChatModel, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.openai.utils import (
     chat_completion_with_backoff,
@@ -83,7 +83,7 @@ def extract_questions(
     prompt = construct_structured_message(
         message=prompt,
         images=query_images,
-        model_type=ChatModelOptions.ModelType.OPENAI,
+        model_type=ChatModel.ModelType.OPENAI,
         vision_enabled=vision_enabled,
         attached_file_context=query_files,
     )
@@ -128,7 +128,7 @@ def send_message_to_model(
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
-        model=model,
+        model_name=model,
         openai_api_key=api_key,
         temperature=temperature,
         api_base_url=api_base_url,
@@ -220,7 +220,7 @@ def converse_openai(
         tokenizer_name=tokenizer_name,
         query_images=query_images,
         vision_enabled=vision_available,
-        model_type=ChatModelOptions.ModelType.OPENAI,
+        model_type=ChatModel.ModelType.OPENAI,
         query_files=query_files,
         generated_files=generated_files,
         generated_asset_results=generated_asset_results,

View file

@@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
+    messages,
+    model_name: str,
+    temperature=0,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs: dict = {},
+    tracer: dict = {},
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -56,7 +62,7 @@ def completion_with_backoff(
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model.startswith("o1"):
+    if model_name.startswith("o1"):
         temperature = 1
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
@@ -66,12 +72,12 @@ def completion_with_backoff(
     chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
-        model=model,  # type: ignore
+        model=model_name,  # type: ignore
         stream=stream,
         stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
-        **(model_kwargs or dict()),
+        **model_kwargs,
     )

     aggregated_response = ""
@@ -91,10 +97,11 @@ def completion_with_backoff(
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
     output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
-    tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
+    cost = chunk.usage.model_extra.get("estimated_cost") or 0  # Estimated costs returned by DeepInfra API
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost)

     # Save conversation trace
-    tracer["chat_model"] = model
+    tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
     if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
@@ -139,11 +146,11 @@ def chat_completion_with_backoff(
 def llm_thread(
     g,
     messages,
-    model_name,
+    model_name: str,
     temperature,
     openai_api_key=None,
     api_base_url=None,
-    model_kwargs=None,
+    model_kwargs: dict = {},
     tracer: dict = {},
 ):
     try:
@@ -177,7 +184,7 @@ def llm_thread(
             stream_options={"include_usage": True} if stream else {},
             temperature=temperature,
             timeout=20,
-            **(model_kwargs or dict()),
+            **model_kwargs,
         )

         aggregated_response = ""
@@ -202,7 +209,8 @@ def llm_thread(
         # Calculate cost of chat
         input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
         output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
-        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+        cost = chunk.usage.model_extra.get("estimated_cost") or 0  # Estimated costs returned by DeepInfra API
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost)

         # Save conversation trace
         tracer["chat_model"] = model_name

View file

@@ -24,7 +24,7 @@ from llama_cpp.llama import Llama
 from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
-from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
+from khoj.database.models import ChatModel, ClientApplication, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.search_filter.base_filter import BaseFilter
@@ -330,9 +330,9 @@ def construct_structured_message(
     Format messages into appropriate multimedia format for supported chat model types
     """
     if model_type in [
-        ChatModelOptions.ModelType.OPENAI,
-        ChatModelOptions.ModelType.GOOGLE,
-        ChatModelOptions.ModelType.ANTHROPIC,
+        ChatModel.ModelType.OPENAI,
+        ChatModel.ModelType.GOOGLE,
+        ChatModel.ModelType.ANTHROPIC,
     ]:
         if not attached_file_context and not (vision_enabled and images):
             return message

View file

@@ -28,12 +28,7 @@ from khoj.database.adapters import (
     get_default_search_model,
     get_user_photo,
 )
-from khoj.database.models import (
-    Agent,
-    ChatModelOptions,
-    KhojUser,
-    SpeechToTextModelOptions,
-)
+from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.anthropic.anthropic_chat import (
     extract_questions_anthropic,
@@ -404,15 +399,15 @@ async def extract_references_and_questions(
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
-        vision_enabled = conversation_config.vision_enabled
+        chat_model = await ConversationAdapters.aget_default_chat_model(user)
+        vision_enabled = chat_model.vision_enabled

-        if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+        if chat_model.model_type == ChatModel.ModelType.OFFLINE:
             using_offline_chat = True
-            chat_model = conversation_config.chat_model
-            max_tokens = conversation_config.max_prompt_size
+            chat_model_name = chat_model.name
+            max_tokens = chat_model.max_prompt_size
             if state.offline_chat_processor_config is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)

             loaded_model = state.offline_chat_processor_config.loaded_model
@@ -424,18 +419,18 @@ async def extract_references_and_questions(
                 should_extract_questions=True,
                 location_data=location_data,
                 user=user,
-                max_prompt_size=conversation_config.max_prompt_size,
+                max_prompt_size=chat_model.max_prompt_size,
                 personality_context=personality_context,
                 query_files=query_files,
                 tracer=tracer,
             )
-        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
-            api_key = conversation_config.ai_model_api.api_key
-            base_url = conversation_config.ai_model_api.api_base_url
-            chat_model = conversation_config.chat_model
+        elif chat_model.model_type == ChatModel.ModelType.OPENAI:
+            api_key = chat_model.ai_model_api.api_key
+            base_url = chat_model.ai_model_api.api_base_url
+            chat_model_name = chat_model.name
             inferred_queries = extract_questions(
                 defiltered_query,
-                model=chat_model,
+                model=chat_model_name,
                 api_key=api_key,
                 api_base_url=base_url,
                 conversation_log=meta_log,
@@ -447,13 +442,13 @@ async def extract_references_and_questions(
                 query_files=query_files,
                 tracer=tracer,
             )
-        elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
-            api_key = conversation_config.ai_model_api.api_key
-            chat_model = conversation_config.chat_model
+        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
+            api_key = chat_model.ai_model_api.api_key
+            chat_model_name = chat_model.name
             inferred_queries = extract_questions_anthropic(
                 defiltered_query,
                 query_images=query_images,
-                model=chat_model,
+                model=chat_model_name,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
@@ -463,17 +458,17 @@ async def extract_references_and_questions(
                 query_files=query_files,
                 tracer=tracer,
             )
-        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
-            api_key = conversation_config.ai_model_api.api_key
-            chat_model = conversation_config.chat_model
+        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
+            api_key = chat_model.ai_model_api.api_key
+            chat_model_name = chat_model.name
             inferred_queries = extract_questions_gemini(
                 defiltered_query,
                 query_images=query_images,
-                model=chat_model,
+                model=chat_model_name,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
-                max_tokens=conversation_config.max_prompt_size,
+                max_tokens=chat_model.max_prompt_size,
                 user=user,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,

View file

@@ -62,7 +62,7 @@ async def all_agents(
             "color": agent.style_color,
             "icon": agent.style_icon,
             "privacy_level": agent.privacy_level,
-            "chat_model": agent.chat_model.chat_model,
+            "chat_model": agent.chat_model.name,
             "files": file_names,
             "input_tools": agent.input_tools,
             "output_modes": agent.output_modes,
@@ -150,7 +150,7 @@ async def get_agent(
             "color": agent.style_color,
             "icon": agent.style_icon,
             "privacy_level": agent.privacy_level,
-            "chat_model": agent.chat_model.chat_model,
+            "chat_model": agent.chat_model.name,
             "files": file_names,
             "input_tools": agent.input_tools,
             "output_modes": agent.output_modes,
@@ -225,7 +225,7 @@ async def create_agent(
             "color": agent.style_color,
             "icon": agent.style_icon,
             "privacy_level": agent.privacy_level,
-            "chat_model": agent.chat_model.chat_model,
+            "chat_model": agent.chat_model.name,
             "files": body.files,
             "input_tools": agent.input_tools,
             "output_modes": agent.output_modes,
@@ -286,7 +286,7 @@ async def update_agent(
             "color": agent.style_color,
             "icon": agent.style_icon,
             "privacy_level": agent.privacy_level,
-            "chat_model": agent.chat_model.chat_model,
+            "chat_model": agent.chat_model.name,
             "files": body.files,
             "input_tools": agent.input_tools,
             "output_modes": agent.output_modes,


@@ -58,7 +58,7 @@ from khoj.routers.helpers import (
     is_ready_to_chat,
     read_chat_stream,
     update_telemetry_state,
-    validate_conversation_config,
+    validate_chat_model,
 )
 from khoj.routers.research import (
     InformationCollectionIteration,
@@ -205,7 +205,7 @@ def chat_history(
     n: Optional[int] = None,
 ):
     user = request.user.object
-    validate_conversation_config(user)
+    validate_chat_model(user)

     # Load Conversation History
     conversation = ConversationAdapters.get_conversation_by_user(
@@ -898,10 +898,10 @@ async def chat(
         custom_filters = []
         if conversation_commands == [ConversationCommand.Help]:
             if not q:
-                conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
-                if conversation_config == None:
-                    conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
-                model_type = conversation_config.model_type
+                chat_model = await ConversationAdapters.aget_user_chat_model(user)
+                if chat_model == None:
+                    chat_model = await ConversationAdapters.aget_default_chat_model(user)
+                model_type = chat_model.model_type
                 formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
                 async for result in send_llm_response(formatted_help, tracer.get("usage")):
                     yield result
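The `/help` handler resolves which model to report with a two-step fallback: the user's own selection first, then the server default. A hedged sketch of that fallback, with stub adapters standing in for `ConversationAdapters`:

```python
import asyncio
from typing import Optional


class ChatModelStub:
    """Stand-in for the ChatModel row; only model_type matters here."""

    def __init__(self, model_type: str):
        self.model_type = model_type


# Hypothetical async adapter stubs mirroring aget_user_chat_model / aget_default_chat_model.
async def aget_user_chat_model(user) -> Optional[ChatModelStub]:
    return None  # e.g. the user never picked a model


async def aget_default_chat_model(user) -> ChatModelStub:
    return ChatModelStub(model_type="openai")


async def resolve_model_type(user) -> str:
    # Same fallback the /help handler uses: prefer the user's choice,
    # otherwise fall back to the server-wide default.
    chat_model = await aget_user_chat_model(user)
    if chat_model is None:
        chat_model = await aget_default_chat_model(user)
    return chat_model.model_type


print(asyncio.run(resolve_model_type(user=None)))  # "openai"
```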


@@ -24,7 +24,7 @@ def get_chat_model_options(
     all_conversation_options = list()
     for conversation_option in conversation_options:
-        all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
+        all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id})

     return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)

@@ -37,12 +37,12 @@ def get_user_chat_model(
 ):
     user = request.user.object
-    chat_model = ConversationAdapters.get_conversation_config(user)
+    chat_model = ConversationAdapters.get_chat_model(user)

     if chat_model is None:
-        chat_model = ConversationAdapters.get_default_conversation_config(user)
+        chat_model = ConversationAdapters.get_default_chat_model(user)

-    return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
+    return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name}))

 @api_model.post("/chat", status_code=200)


@@ -56,7 +56,7 @@ from khoj.database.adapters import (
 )
 from khoj.database.models import (
     Agent,
-    ChatModelOptions,
+    ChatModel,
     ClientApplication,
     Conversation,
     GithubConfig,
@@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool:
     return is_none_or_empty(query.strip())


-def validate_conversation_config(user: KhojUser):
-    default_config = ConversationAdapters.get_default_conversation_config(user)
+def validate_chat_model(user: KhojUser):
+    default_chat_model = ConversationAdapters.get_default_chat_model(user)

-    if default_config is None:
+    if default_chat_model is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")

-    if default_config.model_type == "openai" and not default_config.ai_model_api:
+    if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api:
         raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")


 async def is_ready_to_chat(user: KhojUser):
-    user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
-    if user_conversation_config == None:
-        user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
+    user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
+    if user_chat_model == None:
+        user_chat_model = await ConversationAdapters.aget_default_chat_model(user)

-    if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
-        chat_model = user_conversation_config.chat_model
-        max_tokens = user_conversation_config.max_prompt_size
+    if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
+        chat_model_name = user_chat_model.name
+        max_tokens = user_chat_model.max_prompt_size
         if state.offline_chat_processor_config is None:
             logger.info("Loading Offline Chat Model...")
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
         return True

     if (
-        user_conversation_config
+        user_chat_model
         and (
-            user_conversation_config.model_type
+            user_chat_model.model_type
             in [
-                ChatModelOptions.ModelType.OPENAI,
-                ChatModelOptions.ModelType.ANTHROPIC,
-                ChatModelOptions.ModelType.GOOGLE,
+                ChatModel.ModelType.OPENAI,
+                ChatModel.ModelType.ANTHROPIC,
+                ChatModel.ModelType.GOOGLE,
             ]
         )
-        and user_conversation_config.ai_model_api
+        and user_chat_model.ai_model_api
     ):
         return True
@@ -942,120 +942,124 @@ async def send_message_to_model_wrapper(
     query_files: str = None,
     tracer: dict = {},
 ):
-    conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
-    vision_available = conversation_config.vision_enabled
+    chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
+    vision_available = chat_model.vision_enabled
     if not vision_available and query_images:
-        logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
+        logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
-            conversation_config = vision_enabled_config
+            chat_model = vision_enabled_config
             vision_available = True
     if vision_available and query_images:
-        logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
+        logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")

     subscribed = await ais_user_subscribed(user)
-    chat_model = conversation_config.chat_model
+    chat_model_name = chat_model.name
     max_tokens = (
-        conversation_config.subscribed_max_prompt_size
-        if subscribed and conversation_config.subscribed_max_prompt_size
-        else conversation_config.max_prompt_size
+        chat_model.subscribed_max_prompt_size
+        if subscribed and chat_model.subscribed_max_prompt_size
+        else chat_model.max_prompt_size
     )
-    tokenizer = conversation_config.tokenizer
-    model_type = conversation_config.model_type
-    vision_available = conversation_config.vision_enabled
+    tokenizer = chat_model.tokenizer
+    model_type = chat_model.model_type
+    vision_available = chat_model.vision_enabled

-    if model_type == ChatModelOptions.ModelType.OFFLINE:
+    if model_type == ChatModel.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
         loaded_model = state.offline_chat_processor_config.loaded_model
         truncated_messages = generate_chatml_messages_with_context(
             user_message=query,
             context_message=context,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             loaded_model=loaded_model,
             tokenizer_name=tokenizer,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
-            model=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
             tracer=tracer,
         )

-    elif model_type == ChatModelOptions.ModelType.OPENAI:
-        openai_chat_config = conversation_config.ai_model_api
+    elif model_type == ChatModel.ModelType.OPENAI:
+        openai_chat_config = chat_model.ai_model_api
         api_key = openai_chat_config.api_key
         api_base_url = openai_chat_config.api_base_url
         truncated_messages = generate_chatml_messages_with_context(
             user_message=query,
             context_message=context,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
             query_images=query_images,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
-            model=chat_model,
+            model=chat_model_name,
             response_type=response_type,
             api_base_url=api_base_url,
             tracer=tracer,
         )
-    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
-        api_key = conversation_config.ai_model_api.api_key
+    elif model_type == ChatModel.ModelType.ANTHROPIC:
+        api_key = chat_model.ai_model_api.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=query,
             context_message=context,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
             query_images=query_images,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return anthropic_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
-            model=chat_model,
+            model=chat_model_name,
             response_type=response_type,
             tracer=tracer,
         )
-    elif model_type == ChatModelOptions.ModelType.GOOGLE:
-        api_key = conversation_config.ai_model_api.api_key
+    elif model_type == ChatModel.ModelType.GOOGLE:
+        api_key = chat_model.ai_model_api.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=query,
             context_message=context,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
             query_images=query_images,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return gemini_send_message_to_model(
-            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model_name,
+            response_type=response_type,
+            tracer=tracer,
         )
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync(
     query_files: str = "",
     tracer: dict = {},
 ):
-    conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
+    chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)

-    if conversation_config is None:
+    if chat_model is None:
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

-    chat_model = conversation_config.chat_model
-    max_tokens = conversation_config.max_prompt_size
-    vision_available = conversation_config.vision_enabled
+    chat_model_name = chat_model.name
+    max_tokens = chat_model.max_prompt_size
+    vision_available = chat_model.vision_enabled

-    if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+    if chat_model.model_type == ChatModel.ModelType.OFFLINE:
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)

         loaded_model = state.offline_chat_processor_config.loaded_model
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             loaded_model=loaded_model,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return send_message_to_model_offline(
             messages=truncated_messages,
             loaded_model=loaded_model,
-            model=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
             tracer=tracer,
         )

-    elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
-        api_key = conversation_config.ai_model_api.api_key
+    elif chat_model.model_type == ChatModel.ModelType.OPENAI:
+        api_key = chat_model.ai_model_api.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         openai_response = send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
-            model=chat_model,
+            model=chat_model_name,
             response_type=response_type,
             tracer=tracer,
         )

         return openai_response

-    elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
-        api_key = conversation_config.ai_model_api.api_key
+    elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
+        api_key = chat_model.ai_model_api.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return anthropic_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
-            model=chat_model,
+            model=chat_model_name,
             response_type=response_type,
             tracer=tracer,
         )

-    elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
-        api_key = conversation_config.ai_model_api.api_key
+    elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
+        api_key = chat_model.ai_model_api.api_key
         truncated_messages = generate_chatml_messages_with_context(
             user_message=message,
             system_message=system_message,
-            model_name=chat_model,
+            model_name=chat_model_name,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
+            model_type=chat_model.model_type,
             query_files=query_files,
         )

         return gemini_send_message_to_model(
             messages=truncated_messages,
             api_key=api_key,
-            model=chat_model,
+            model=chat_model_name,
             response_type=response_type,
             tracer=tracer,
         )
@@ -1229,15 +1233,15 @@ def generate_chat_response(
         online_results = {}
         code_results = {}

-        conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
-        vision_available = conversation_config.vision_enabled
+        chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
+        vision_available = chat_model.vision_enabled
         if not vision_available and query_images:
             vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
             if vision_enabled_config:
-                conversation_config = vision_enabled_config
+                chat_model = vision_enabled_config
                 vision_available = True

-        if conversation_config.model_type == "offline":
+        if chat_model.model_type == "offline":
             loaded_model = state.offline_chat_processor_config.loaded_model
             chat_response = converse_offline(
                 user_query=query_to_run,
@@ -1247,9 +1251,9 @@ def generate_chat_response(
                 conversation_log=meta_log,
                 completion_func=partial_completion,
                 conversation_commands=conversation_commands,
-                model=conversation_config.chat_model,
-                max_prompt_size=conversation_config.max_prompt_size,
-                tokenizer_name=conversation_config.tokenizer,
+                model_name=chat_model.name,
+                max_prompt_size=chat_model.max_prompt_size,
+                tokenizer_name=chat_model.tokenizer,
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
@@ -1259,10 +1263,10 @@ def generate_chat_response(
                 tracer=tracer,
             )

-        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
-            openai_chat_config = conversation_config.ai_model_api
+        elif chat_model.model_type == ChatModel.ModelType.OPENAI:
+            openai_chat_config = chat_model.ai_model_api
             api_key = openai_chat_config.api_key
-            chat_model = conversation_config.chat_model
+            chat_model_name = chat_model.name
             chat_response = converse_openai(
                 compiled_references,
                 query_to_run,
@@ -1270,13 +1274,13 @@ def generate_chat_response(
                 online_results=online_results,
                 code_results=code_results,
                 conversation_log=meta_log,
-                model=chat_model,
+                model=chat_model_name,
                 api_key=api_key,
                 api_base_url=openai_chat_config.api_base_url,
                 completion_func=partial_completion,
                 conversation_commands=conversation_commands,
-                max_prompt_size=conversation_config.max_prompt_size,
-                tokenizer_name=conversation_config.tokenizer,
+                max_prompt_size=chat_model.max_prompt_size,
+                tokenizer_name=chat_model.tokenizer,
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
@@ -1288,8 +1292,8 @@ def generate_chat_response(
                 tracer=tracer,
             )

-        elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
-            api_key = conversation_config.ai_model_api.api_key
+        elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
+            api_key = chat_model.ai_model_api.api_key
             chat_response = converse_anthropic(
                 compiled_references,
                 query_to_run,
@@ -1297,12 +1301,12 @@ def generate_chat_response(
                 online_results=online_results,
                 code_results=code_results,
                 conversation_log=meta_log,
-                model=conversation_config.chat_model,
+                model=chat_model.name,
                 api_key=api_key,
                 completion_func=partial_completion,
                 conversation_commands=conversation_commands,
-                max_prompt_size=conversation_config.max_prompt_size,
-                tokenizer_name=conversation_config.tokenizer,
+                max_prompt_size=chat_model.max_prompt_size,
+                tokenizer_name=chat_model.tokenizer,
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
@@ -1313,20 +1317,20 @@ def generate_chat_response(
                 program_execution_context=program_execution_context,
                 tracer=tracer,
             )
-        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
-            api_key = conversation_config.ai_model_api.api_key
+        elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
+            api_key = chat_model.ai_model_api.api_key
             chat_response = converse_gemini(
                 compiled_references,
                 query_to_run,
                 online_results,
                 code_results,
                 meta_log,
-                model=conversation_config.chat_model,
+                model=chat_model.name,
                 api_key=api_key,
                 completion_func=partial_completion,
                 conversation_commands=conversation_commands,
-                max_prompt_size=conversation_config.max_prompt_size,
-                tokenizer_name=conversation_config.tokenizer,
+                max_prompt_size=chat_model.max_prompt_size,
+                tokenizer_name=chat_model.tokenizer,
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
@@ -1339,7 +1343,7 @@ def generate_chat_response(
                 tracer=tracer,
             )

-        metadata.update({"chat_model": conversation_config.chat_model})
+        metadata.update({"chat_model": chat_model.name})

     except Exception as e:
         logger.error(e, exc_info=True)
@@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
     current_notion_config = get_user_notion_config(user)
     notion_token = current_notion_config.token if current_notion_config else ""

-    selected_chat_model_config = ConversationAdapters.get_conversation_config(
+    selected_chat_model_config = ConversationAdapters.get_chat_model(
         user
-    ) or ConversationAdapters.get_default_conversation_config(user)
+    ) or ConversationAdapters.get_default_chat_model(user)
     chat_models = ConversationAdapters.get_conversation_processor_options().all()
     chat_model_options = list()
     for chat_model in chat_models:
-        chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
+        chat_model_options.append({"name": chat_model.name, "id": chat_model.id})

     selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
     paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
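One detail of `send_message_to_model_wrapper` worth calling out is how it picks the context window: subscribers get the larger `subscribed_max_prompt_size` when one is configured, everyone else the base `max_prompt_size`. A runnable sketch of that selection; the dataclass below is an illustrative subset of the renamed model, not the real Django class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatModel:
    # Illustrative subset of the renamed model's fields.
    name: str
    max_prompt_size: Optional[int] = 8192
    subscribed_max_prompt_size: Optional[int] = 32768


def effective_max_tokens(chat_model: ChatModel, subscribed: bool) -> Optional[int]:
    # Mirrors the wrapper's selection: subscribers get the larger context
    # window when one is configured, everyone else the base limit.
    return (
        chat_model.subscribed_max_prompt_size
        if subscribed and chat_model.subscribed_max_prompt_size
        else chat_model.max_prompt_size
    )


model = ChatModel(name="gpt-4o-mini")
print(effective_max_tokens(model, subscribed=False))  # 8192
print(effective_max_tokens(model, subscribed=True))   # 32768
```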


@@ -584,13 +584,15 @@ def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_toke
     return input_cost + output_cost + prev_cost


-def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
+def get_chat_usage_metrics(
+    model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}, cost: float = None
+):
     """
-    Get usage metrics for chat message based on input and output tokens
+    Get usage metrics for chat message based on input and output tokens and cost
     """
     prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
     return {
         "input_tokens": prev_usage["input_tokens"] + input_tokens,
         "output_tokens": prev_usage["output_tokens"] + output_tokens,
-        "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
+        "cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
     }
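The new optional `cost` parameter lets callers pass a provider-reported spend instead of recomputing it from token counts. A small runnable sketch of both paths; the flat-rate pricing function here is a stand-in for the real `get_cost_of_chat_message`:

```python
def get_cost_of_chat_message(model_name, input_tokens=0, output_tokens=0, prev_cost=0.0):
    # Stand-in pricing: $1 per million tokens, input or output.
    return prev_cost + (input_tokens + output_tokens) / 1_000_000


def get_chat_usage_metrics(model_name, input_tokens=0, output_tokens=0, usage=None, cost=None):
    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
    return {
        "input_tokens": prev_usage["input_tokens"] + input_tokens,
        "output_tokens": prev_usage["output_tokens"] + output_tokens,
        # An explicit provider-reported cost wins; otherwise the cost is
        # derived from token counts, as before this change.
        "cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
    }


print(get_chat_usage_metrics("gpt-4o-mini", 1000, 500))             # derived cost
print(get_chat_usage_metrics("gpt-4o-mini", 1000, 500, cost=0.02))  # reported cost wins
```

Note the `or`: a reported cost of exactly `0.0` is falsy, so it falls back to the derived estimate rather than being recorded as zero.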


@@ -7,7 +7,7 @@ import openai
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import (
     AiModelApi,
-    ChatModelOptions,
+    ChatModel,
     KhojUser,
     SpeechToTextModelOptions,
     TextToImageModelConfig,
@@ -63,7 +63,7 @@ def initialization(interactive: bool = True):
     # Set up OpenAI's online chat models
     openai_configured, openai_provider = _setup_chat_model_provider(
-        ChatModelOptions.ModelType.OPENAI,
+        ChatModel.ModelType.OPENAI,
         default_chat_models,
         default_api_key=openai_api_key,
         api_base_url=openai_api_base,
@@ -105,7 +105,7 @@ def initialization(interactive: bool = True):
     # Set up Google's Gemini online chat models
     _setup_chat_model_provider(
-        ChatModelOptions.ModelType.GOOGLE,
+        ChatModel.ModelType.GOOGLE,
         default_gemini_chat_models,
         default_api_key=os.getenv("GEMINI_API_KEY"),
         vision_enabled=True,
@@ -116,7 +116,7 @@ def initialization(interactive: bool = True):
     # Set up Anthropic's online chat models
     _setup_chat_model_provider(
-        ChatModelOptions.ModelType.ANTHROPIC,
+        ChatModel.ModelType.ANTHROPIC,
         default_anthropic_chat_models,
         default_api_key=os.getenv("ANTHROPIC_API_KEY"),
         vision_enabled=True,
@@ -126,7 +126,7 @@ def initialization(interactive: bool = True):
     # Set up offline chat models
     _setup_chat_model_provider(
-        ChatModelOptions.ModelType.OFFLINE,
+        ChatModel.ModelType.OFFLINE,
         default_offline_chat_models,
         default_api_key=None,
         vision_enabled=False,
@@ -135,9 +135,9 @@ def initialization(interactive: bool = True):
     )

     # Explicitly set default chat model
-    chat_models_configured = ChatModelOptions.objects.count()
+    chat_models_configured = ChatModel.objects.count()
     if chat_models_configured > 0:
-        default_chat_model_name = ChatModelOptions.objects.first().chat_model
+        default_chat_model_name = ChatModel.objects.first().name
         # If there are multiple chat models, ask the user to choose the default chat model
         if chat_models_configured > 1 and interactive:
             user_chat_model_name = input(
@@ -147,7 +147,7 @@ def initialization(interactive: bool = True):
             user_chat_model_name = None

         # If the user's choice is valid, set it as the default chat model
-        if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
+        if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists():
             default_chat_model_name = user_chat_model_name

         logger.info("🗣️ Chat model configuration complete")
@@ -171,7 +171,7 @@ def initialization(interactive: bool = True):
         logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")

     def _setup_chat_model_provider(
-        model_type: ChatModelOptions.ModelType,
+        model_type: ChatModel.ModelType,
         default_chat_models: list,
         default_api_key: str,
         interactive: bool,
@@ -204,10 +204,10 @@ def initialization(interactive: bool = True):
         ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)

         if interactive:
-            chat_model_names = input(
+            user_chat_models = input(
                 f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
             )
-            chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
+            chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
             chat_models = [model.strip() for model in chat_models]
         else:
             chat_models = default_chat_models
@@ -218,7 +218,7 @@ def initialization(interactive: bool = True):
             vision_enabled = vision_enabled and chat_model in supported_vision_models

             chat_model_options = {
-                "chat_model": chat_model,
+                "name": chat_model,
                 "model_type": model_type,
                 "max_prompt_size": default_max_tokens,
                 "vision_enabled": vision_enabled,
@@ -226,7 +226,7 @@ def initialization(interactive: bool = True):
                 "ai_model_api": ai_model_api,
             }

-            ChatModelOptions.objects.create(**chat_model_options)
+            ChatModel.objects.create(**chat_model_options)

         logger.info(f"🗣️ {provider_name} chat model configuration complete")
         return True, ai_model_api
@@ -250,19 +250,19 @@ def initialization(interactive: bool = True):
         available_models = [model.id for model in openai_client.models.list()]

         # Get existing chat model options for this config
-        existing_models = ChatModelOptions.objects.filter(
-            ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI
+        existing_models = ChatModel.objects.filter(
+            ai_model_api=config, model_type=ChatModel.ModelType.OPENAI
         )

         # Add new models
-        for model in available_models:
-            if not existing_models.filter(chat_model=model).exists():
-                ChatModelOptions.objects.create(
-                    chat_model=model,
-                    model_type=ChatModelOptions.ModelType.OPENAI,
-                    max_prompt_size=model_to_prompt_size.get(model),
-                    vision_enabled=model in default_openai_chat_models,
-                    tokenizer=model_to_tokenizer.get(model),
+        for model_name in available_models:
+            if not existing_models.filter(name=model_name).exists():
+                ChatModel.objects.create(
+                    name=model_name,
+                    model_type=ChatModel.ModelType.OPENAI,
+                    max_prompt_size=model_to_prompt_size.get(model_name),
+                    vision_enabled=model_name in default_openai_chat_models,
+                    tokenizer=model_to_tokenizer.get(model_name),
                     ai_model_api=config,
                 )
@@ -284,7 +284,7 @@ def initialization(interactive: bool = True):
     except Exception as e:
         logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)

-    chat_config = ConversationAdapters.get_default_conversation_config()
+    chat_config = ConversationAdapters.get_default_chat_model()
     if admin_user is None and chat_config is None:
         while True:
             try:
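After the rename, first-run initialization creates `ChatModel` rows keyed by `name` rather than the old `chat_model` field. A hedged sketch of the equivalent ORM calls; it assumes a configured Django environment with Khoj's models importable, and the key and model values are placeholders:

```python
# Sketch only: assumes Django is set up and Khoj's app is importable.
from khoj.database.models import AiModelApi, ChatModel

# Provider credentials live on AiModelApi; the api_key here is a placeholder.
provider = AiModelApi.objects.create(name="openai", api_key="<your-api-key>", api_base_url=None)

ChatModel.objects.create(
    name="gpt-4o-mini",  # was the `chat_model` field before this change
    model_type=ChatModel.ModelType.OPENAI,
    max_prompt_size=20000,
    vision_enabled=True,
    ai_model_api=provider,
)

# Lookups filter on the new field name too:
assert ChatModel.objects.filter(name="gpt-4o-mini").exists()
```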


@@ -13,7 +13,7 @@ from khoj.configure import (
 )
 from khoj.database.models import (
     Agent,
-    ChatModelOptions,
+    ChatModel,
     GithubConfig,
     GithubRepoConfig,
     KhojApiUser,
@@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path
 from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
 from tests.helpers import (
     AiModelApiFactory,
-    ChatModelOptionsFactory,
+    ChatModelFactory,
     ProcessLockFactory,
     SubscriptionFactory,
     UserConversationProcessorConfigFactory,
@@ -184,14 +184,14 @@ def api_user4(default_user4):
 @pytest.mark.django_db
 @pytest.fixture
 def default_openai_chat_model_option():
-    chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
+    chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
     return chat_model


 @pytest.mark.django_db
 @pytest.fixture
 def offline_agent():
-    chat_model = ChatModelOptionsFactory()
+    chat_model = ChatModelFactory()
     return Agent.objects.create(
         name="Accountant",
         chat_model=chat_model,
@@ -202,7 +202,7 @@ def offline_agent():
 @pytest.mark.django_db
 @pytest.fixture
 def openai_agent():
-    chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
+    chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
     return Agent.objects.create(
         name="Accountant",
         chat_model=chat_model,
@@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
     # Initialize Processor from Config
     chat_provider = get_chat_provider()
-    online_chat_model: ChatModelOptionsFactory = None
-    if chat_provider == ChatModelOptions.ModelType.OPENAI:
-        online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
-    elif chat_provider == ChatModelOptions.ModelType.GOOGLE:
-        online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google")
-    elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC:
-        online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic")
+    online_chat_model: ChatModelFactory = None
+    if chat_provider == ChatModel.ModelType.OPENAI:
+        online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
+    elif chat_provider == ChatModel.ModelType.GOOGLE:
+        online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
+    elif chat_provider == ChatModel.ModelType.ANTHROPIC:
+        online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
     if online_chat_model:
         online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
         UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
@@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
     configure_content(default_user2, all_files)

     # Initialize Processor from Config
-    ChatModelOptionsFactory(
-        chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
+    ChatModelFactory(
+        name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
         tokenizer=None,
         max_prompt_size=None,
         model_type="offline",


@@ -6,7 +6,7 @@ from django.utils.timezone import make_aware

 from khoj.database.models import (
     AiModelApi,
-    ChatModelOptions,
+    ChatModel,
     Conversation,
     KhojApiUser,
     KhojUser,
@@ -18,27 +18,27 @@ from khoj.database.models import (
 from khoj.processor.conversation.utils import message_to_log


-def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
+def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
     provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
-    if provider and provider in ChatModelOptions.ModelType:
-        return ChatModelOptions.ModelType(provider)
+    if provider and provider in ChatModel.ModelType:
+        return ChatModel.ModelType(provider)
     elif os.getenv("OPENAI_API_KEY"):
-        return ChatModelOptions.ModelType.OPENAI
+        return ChatModel.ModelType.OPENAI
     elif os.getenv("GEMINI_API_KEY"):
-        return ChatModelOptions.ModelType.GOOGLE
+        return ChatModel.ModelType.GOOGLE
     elif os.getenv("ANTHROPIC_API_KEY"):
-        return ChatModelOptions.ModelType.ANTHROPIC
+        return ChatModel.ModelType.ANTHROPIC
     else:
         return default


-def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
+def get_chat_api_key(provider: ChatModel.ModelType = None):
     provider = provider or get_chat_provider()
-    if provider == ChatModelOptions.ModelType.OPENAI:
+    if provider == ChatModel.ModelType.OPENAI:
         return os.getenv("OPENAI_API_KEY")
-    elif provider == ChatModelOptions.ModelType.GOOGLE:
+    elif provider == ChatModel.ModelType.GOOGLE:
         return os.getenv("GEMINI_API_KEY")
-    elif provider == ChatModelOptions.ModelType.ANTHROPIC:
+    elif provider == ChatModel.ModelType.ANTHROPIC:
         return os.getenv("ANTHROPIC_API_KEY")
     else:
         return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
@@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory):
     api_key = get_chat_api_key()


-class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
+class ChatModelFactory(factory.django.DjangoModelFactory):
     class Meta:
-        model = ChatModelOptions
+        model = ChatModel

     max_prompt_size = 20000
     tokenizer = None
-    chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
+    name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
     model_type = get_chat_provider()
     ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
@@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
         model = UserConversationConfig

     user = factory.SubFactory(UserFactory)
-    setting = factory.SubFactory(ChatModelOptionsFactory)
+    setting = factory.SubFactory(ChatModelFactory)


 class ConversationFactory(factory.django.DjangoModelFactory):
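The factory rename ripples into every fixture that builds a chat model. A sketch of how a test would use the renamed factory after this change; it mirrors the fixtures above and assumes pytest-django plus the `default_user` fixture from conftest:

```python
import pytest

from tests.helpers import ChatModelFactory, UserConversationProcessorConfigFactory


@pytest.mark.django_db
def test_user_gets_configured_chat_model(default_user):
    # `name=` replaces the old `chat_model=` factory kwarg.
    chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
    setting = UserConversationProcessorConfigFactory(user=default_user, setting=chat_model)
    assert setting.setting.name == "gpt-4o-mini"
```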


@@ -5,14 +5,14 @@ import pytest
 from asgiref.sync import sync_to_async

 from khoj.database.adapters import AgentAdapters
-from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
+from khoj.database.models import Agent, ChatModel, Entry, KhojUser
 from khoj.routers.api import execute_search
 from khoj.utils.helpers import get_absolute_path
-from tests.helpers import ChatModelOptionsFactory
+from tests.helpers import ChatModelFactory


 def test_create_default_agent(default_user: KhojUser):
-    ChatModelOptionsFactory()
+    ChatModelFactory()

     agent = AgentAdapters.create_default_agent(default_user)
     assert agent is not None
@@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser):

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
-async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
+async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel):
     new_agent = await AgentAdapters.aupdate_agent(
         default_user,
         "Test Agent",
@@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
         Agent.PrivacyLevel.PRIVATE,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [],
         [],
         [],
@@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_create_or_update_agent_with_knowledge_base(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base(
         Agent.PrivacyLevel.PRIVATE,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base(

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_create_or_update_agent_with_knowledge_base_and_search(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
         Agent.PrivacyLevel.PRIVATE,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_agent_with_knowledge_base_and_search_not_creator(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
         Agent.PrivacyLevel.PUBLIC,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
         Agent.PrivacyLevel.PRIVATE,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
         Agent.PrivacyLevel.PRIVATE,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce

 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
 async def test_multiple_agents_with_knowledge_base_and_users(
-    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
+    default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
 ):
     full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
     new_agent = await AgentAdapters.aupdate_agent(
@@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
         Agent.PrivacyLevel.PUBLIC,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename],
         [],
         [],
@@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
         Agent.PrivacyLevel.PUBLIC,
         "icon",
         "color",
-        default_openai_chat_model_option.chat_model,
+        default_openai_chat_model_option.name,
         [full_filename2],
         [],
         [],


@@ -2,12 +2,12 @@ from datetime import datetime

 import pytest

-from khoj.database.models import ChatModelOptions
+from khoj.database.models import ChatModel
 from khoj.routers.helpers import aget_data_sources_and_output_format
 from khoj.utils.helpers import ConversationCommand
 from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider

-SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
+SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
 pytestmark = pytest.mark.skipif(
     SKIP_TESTS,
     reason="Disable in CI to avoid long test runs.",


@@ -4,12 +4,12 @@ import pytest
 from faker import Faker
 from freezegun import freeze_time

-from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
+from khoj.database.models import Agent, ChatModel, Entry, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
 from tests.helpers import ConversationFactory, get_chat_provider

-SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
+SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
 pytestmark = pytest.mark.skipif(
     SKIP_TESTS,
     reason="Disable in CI to avoid long test runs.",
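Both offline-chat test modules gate themselves on the resolved provider this way. A standalone sketch of the pattern with the renamed enum; `get_chat_provider` is stubbed here with a simplified string-returning version so the snippet runs without Khoj installed:

```python
import os

import pytest


# Simplified stand-in for tests.helpers.get_chat_provider: resolve the provider
# under test from the environment, defaulting to None so CI can opt out entirely.
def get_chat_provider(default=None):
    return os.getenv("KHOJ_TEST_CHAT_PROVIDER", default)


# Skip the whole module unless the offline provider is explicitly selected.
SKIP_TESTS = get_chat_provider(default=None) != "offline"
pytestmark = pytest.mark.skipif(SKIP_TESTS, reason="Disable in CI to avoid long test runs.")


def test_offline_chat_smoke():
    assert True
```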