diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml index dc8c89b7..836d03b1 100644 --- a/.github/workflows/run_evals.yml +++ b/.github/workflows/run_evals.yml @@ -113,7 +113,8 @@ jobs: khoj --anonymous-mode --non-interactive & # Start code sandbox - npm run dev --prefix terrarium & + npm install -g pm2 + npm run ci --prefix terrarium # Wait for server to be ready timeout=120 diff --git a/documentation/docs/advanced/litellm.md b/documentation/docs/advanced/litellm.md index 212ac047..7700836d 100644 --- a/documentation/docs/advanced/litellm.md +++ b/documentation/docs/advanced/litellm.md @@ -25,7 +25,7 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you - Name: `proxy-name` - Api Key: `any string` - Api Base Url: **URL of your Openai Proxy API** -4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel. +4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel. - Name: `llama3.1` (replace with the name of your local model) - Model Type: `Openai` - Openai Config: `` diff --git a/documentation/docs/advanced/lmstudio.md b/documentation/docs/advanced/lmstudio.md index 5c5ab567..1ecd7f06 100644 --- a/documentation/docs/advanced/lmstudio.md +++ b/documentation/docs/advanced/lmstudio.md @@ -18,7 +18,7 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/ - Name: `proxy-name` - Api Key: `any string` - Api Base Url: `http://localhost:1234/v1/` (default for LMStudio) -4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel. +4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel. - Name: `llama3.1` (replace with the name of your local model) - Model Type: `Openai` - Openai Config: `` diff --git a/documentation/docs/advanced/ollama.mdx b/documentation/docs/advanced/ollama.mdx index 78d77d26..486357e8 100644 --- a/documentation/docs/advanced/ollama.mdx +++ b/documentation/docs/advanced/ollama.mdx @@ -64,7 +64,7 @@ Restart your Khoj server after first run or update to the settings below to ensu - Name: `ollama` - Api Key: `any string` - Api Base Url: `http://localhost:11434/v1/` (default for Ollama) - 4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel. + 4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel. - Name: `llama3.1` (replace with the name of your local model) - Model Type: `Openai` - Openai Config: `` diff --git a/documentation/docs/advanced/use-openai-proxy.md b/documentation/docs/advanced/use-openai-proxy.md index 6efaad1c..65993cb1 100644 --- a/documentation/docs/advanced/use-openai-proxy.md +++ b/documentation/docs/advanced/use-openai-proxy.md @@ -25,7 +25,7 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan - Name: `any name` - Api Key: `any string` - Api Base Url: **URL of your Openai Proxy API** -3. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel. +3. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel. - Name: `llama3` (replace with the name of your local model) - Model Type: `Openai` - Openai Config: `` diff --git a/documentation/docs/features/image_generation.md b/documentation/docs/features/image_generation.md index 14a005b3..d277d592 100644 --- a/documentation/docs/features/image_generation.md +++ b/documentation/docs/features/image_generation.md @@ -17,7 +17,7 @@ You have a couple of image generation options. We support most state of the art image generation models, including Ideogram, Flux, and Stable Diffusion. These will run using [Replicate](https://replicate.com). Here's how to set them up: 1. Get a Replicate API key [here](https://replicate.com/account/api-tokens). -1. Create a new [Text to Image Model](https://app.khoj.dev/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models). +1. Create a new [Text to Image Model](http://localhost:42110/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models). ### OpenAI diff --git a/documentation/docs/get-started/setup.mdx b/documentation/docs/get-started/setup.mdx index a3081e96..86be2688 100644 --- a/documentation/docs/get-started/setup.mdx +++ b/documentation/docs/get-started/setup.mdx @@ -307,7 +307,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu - Give the configuration a friendly name like `OpenAI` - (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).
![example configuration for ai model api](/img/example_openai_processor_config.png) -2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add) +2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add) - Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`. - Make sure to set the `model-type` field to `OpenAI`. - If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities. @@ -318,7 +318,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu 1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings. - Add your [Anthropic API key](https://console.anthropic.com/account/keys) - Give the configuration a friendly name like `Anthropic`. Do not configure the API base url. -2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add) +2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add) - Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`. - Set the `model-type` field to `Anthropic`. - Set the `ai model api` field to the Anthropic AI Model API you created in step 1. @@ -327,7 +327,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu 1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings. - Add your [Gemini API key](https://aistudio.google.com/app/apikey) - Give the configuration a friendly name like `Gemini`. Do not configure the API base url. -2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add) +2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add) - Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`. - Set the `model-type` field to `Gemini`. - Set the `ai model api` field to the Gemini AI Model API you created in step 1. @@ -343,7 +343,7 @@ Offline chat stays completely private and can work without internet using any op ::: 1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*. -2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodeloptions/add/) on the admin panel +2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel 3. Set the `chat-model` field to the name of your preferred chat model - Make sure the `model-type` is set to `Offline` 4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/). diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9dee6684..6f08d986 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -36,7 +36,7 @@ from torch import Tensor from khoj.database.models import ( Agent, AiModelApi, - ChatModelOptions, + ChatModel, ClientApplication, Conversation, Entry, @@ -736,8 +736,8 @@ class AgentAdapters: @staticmethod def create_default_agent(user: KhojUser): - default_conversation_config = ConversationAdapters.get_default_conversation_config(user) - if default_conversation_config is None: + default_chat_model = ConversationAdapters.get_default_chat_model(user) + if default_chat_model is None: logger.info("No default conversation config found, skipping default agent creation") return None default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder") @@ -746,7 +746,7 @@ class AgentAdapters: if agent: agent.personality = default_personality - agent.chat_model = default_conversation_config + agent.chat_model = default_chat_model agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG agent.name = AgentAdapters.DEFAULT_AGENT_NAME agent.privacy_level = Agent.PrivacyLevel.PUBLIC @@ -760,7 +760,7 @@ class AgentAdapters: name=AgentAdapters.DEFAULT_AGENT_NAME, privacy_level=Agent.PrivacyLevel.PUBLIC, managed_by_admin=True, - chat_model=default_conversation_config, + chat_model=default_chat_model, personality=default_personality, slug=AgentAdapters.DEFAULT_AGENT_SLUG, ) @@ -787,7 +787,7 @@ class AgentAdapters: output_modes: List[str], slug: Optional[str] = None, ): - chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst() + chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst() # Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create( @@ -972,29 +972,29 @@ class ConversationAdapters: @staticmethod @require_valid_user - def has_any_conversation_config(user: KhojUser): - return ChatModelOptions.objects.filter(user=user).exists() + def has_any_chat_model(user: KhojUser): + return ChatModel.objects.filter(user=user).exists() @staticmethod - def get_all_conversation_configs(): - return ChatModelOptions.objects.all() + def get_all_chat_models(): + return ChatModel.objects.all() @staticmethod - async def aget_all_conversation_configs(): - return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all()) + async def aget_all_chat_models(): + return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all()) @staticmethod def get_vision_enabled_config(): - conversation_configurations = ConversationAdapters.get_all_conversation_configs() - for config in conversation_configurations: + chat_models = ConversationAdapters.get_all_chat_models() + for config in chat_models: if config.vision_enabled: return config return None @staticmethod async def aget_vision_enabled_config(): - conversation_configurations = await ConversationAdapters.aget_all_conversation_configs() - for config in conversation_configurations: + chat_models = await ConversationAdapters.aget_all_chat_models() + for config in chat_models: if config.vision_enabled: return config return None @@ -1010,7 +1010,7 @@ class ConversationAdapters: @staticmethod @arequire_valid_user async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int): - config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst() + config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst() if not config: return None new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) @@ -1026,24 +1026,24 @@ class ConversationAdapters: return new_config @staticmethod - def get_conversation_config(user: KhojUser): + def get_chat_model(user: KhojUser): subscribed = is_user_subscribed(user) if not subscribed: - return ConversationAdapters.get_default_conversation_config(user) + return ConversationAdapters.get_default_chat_model(user) config = UserConversationConfig.objects.filter(user=user).first() if config: return config.setting - return ConversationAdapters.get_advanced_conversation_config(user) + return ConversationAdapters.get_advanced_chat_model(user) @staticmethod - async def aget_conversation_config(user: KhojUser): + async def aget_chat_model(user: KhojUser): subscribed = await ais_user_subscribed(user) if not subscribed: - return await ConversationAdapters.aget_default_conversation_config(user) + return await ConversationAdapters.aget_default_chat_model(user) config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() if config: return config.setting - return ConversationAdapters.aget_advanced_conversation_config(user) + return ConversationAdapters.aget_advanced_chat_model(user) @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: @@ -1064,7 +1064,7 @@ class ConversationAdapters: return VoiceModelOption.objects.first() @staticmethod - def get_default_conversation_config(user: KhojUser = None): + def get_default_chat_model(user: KhojUser = None): """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" # Get the server chat settings server_chat_settings = ServerChatSettings.objects.first() @@ -1084,10 +1084,10 @@ class ConversationAdapters: return user_chat_settings.setting # Get the first chat model if even the user chat settings are not set - return ChatModelOptions.objects.filter().first() + return ChatModel.objects.filter().first() @staticmethod - async def aget_default_conversation_config(user: KhojUser = None): + async def aget_default_chat_model(user: KhojUser = None): """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" # Get the server chat settings server_chat_settings: ServerChatSettings = ( @@ -1117,17 +1117,17 @@ class ConversationAdapters: return user_chat_settings.setting # Get the first chat model if even the user chat settings are not set - return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst() + return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst() @staticmethod - def get_advanced_conversation_config(user: KhojUser): + def get_advanced_chat_model(user: KhojUser): server_chat_settings = ServerChatSettings.objects.first() if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return ConversationAdapters.get_default_conversation_config(user) + return ConversationAdapters.get_default_chat_model(user) @staticmethod - async def aget_advanced_conversation_config(user: KhojUser = None): + async def aget_advanced_chat_model(user: KhojUser = None): server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related("chat_advanced", "chat_advanced__ai_model_api") @@ -1135,7 +1135,7 @@ class ConversationAdapters: ) if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: return server_chat_settings.chat_advanced - return await ConversationAdapters.aget_default_conversation_config(user) + return await ConversationAdapters.aget_default_chat_model(user) @staticmethod async def aget_server_webscraper(): @@ -1247,16 +1247,16 @@ class ConversationAdapters: @staticmethod def get_conversation_processor_options(): - return ChatModelOptions.objects.all() + return ChatModel.objects.all() @staticmethod - def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions): + def set_user_chat_model(user: KhojUser, chat_model: ChatModel): user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user) - user_conversation_config.setting = new_config + user_conversation_config.setting = chat_model user_conversation_config.save() @staticmethod - async def aget_user_conversation_config(user: KhojUser): + async def aget_user_chat_model(user: KhojUser): config = ( await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst() ) @@ -1288,33 +1288,33 @@ class ConversationAdapters: return random.sample(all_questions, max_results) @staticmethod - def get_valid_conversation_config(user: KhojUser, conversation: Conversation): + def get_valid_chat_model(user: KhojUser, conversation: Conversation): agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None if agent and agent.chat_model: - conversation_config = conversation.agent.chat_model + chat_model = conversation.agent.chat_model else: - conversation_config = ConversationAdapters.get_conversation_config(user) + chat_model = ConversationAdapters.get_chat_model(user) - if conversation_config is None: - conversation_config = ConversationAdapters.get_default_conversation_config() + if chat_model is None: + chat_model = ConversationAdapters.get_default_chat_model() - if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: + if chat_model.model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - chat_model = conversation_config.chat_model - max_tokens = conversation_config.max_prompt_size - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + chat_model_name = chat_model.name + max_tokens = chat_model.max_prompt_size + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) - return conversation_config + return chat_model if ( - conversation_config.model_type + chat_model.model_type in [ - ChatModelOptions.ModelType.ANTHROPIC, - ChatModelOptions.ModelType.OPENAI, - ChatModelOptions.ModelType.GOOGLE, + ChatModel.ModelType.ANTHROPIC, + ChatModel.ModelType.OPENAI, + ChatModel.ModelType.GOOGLE, ] - ) and conversation_config.ai_model_api: - return conversation_config + ) and chat_model.ai_model_api: + return chat_model else: raise ValueError("Invalid conversation config - either configure offline chat or openai chat") diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index ce1060c1..fb46f973 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -16,7 +16,7 @@ from unfold import admin as unfold_admin from khoj.database.models import ( Agent, AiModelApi, - ChatModelOptions, + ChatModel, ClientApplication, Conversation, Entry, @@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin): list_filter = ("type",) -@admin.register(ChatModelOptions) -class ChatModelOptionsAdmin(unfold_admin.ModelAdmin): +@admin.register(ChatModel) +class ChatModelAdmin(unfold_admin.ModelAdmin): list_display = ( "id", - "chat_model", + "name", "ai_model_api", "max_prompt_size", ) - search_fields = ("id", "chat_model", "ai_model_api__name") + search_fields = ("id", "name", "ai_model_api__name") @admin.register(TextToImageModelConfig) @@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin): "get_chat_model", "get_subscription_type", ) - search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type") + search_fields = ("id", "user__email", "setting__name", "user__subscription__type") ordering = ("-updated_at",) def get_user_email(self, obj): @@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin): get_user_email.admin_order_field = "user__email" # type: ignore def get_chat_model(self, obj): - return obj.setting.chat_model if obj.setting else None + return obj.setting.name if obj.setting else None get_chat_model.short_description = "Chat Model" # type: ignore - get_chat_model.admin_order_field = "setting__chat_model" # type: ignore + get_chat_model.admin_order_field = "setting__name" # type: ignore def get_subscription_type(self, obj): if hasattr(obj.user, "subscription"): diff --git a/src/khoj/database/migrations/0077_chatmodel_alter_agent_chat_model_and_more.py b/src/khoj/database/migrations/0077_chatmodel_alter_agent_chat_model_and_more.py new file mode 100644 index 00000000..a8e1ebde --- /dev/null +++ b/src/khoj/database/migrations/0077_chatmodel_alter_agent_chat_model_and_more.py @@ -0,0 +1,62 @@ +# Generated by Django 5.0.9 on 2024-12-09 04:21 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"), + ] + + operations = [ + migrations.RenameModel( + old_name="ChatModelOptions", + new_name="ChatModel", + ), + migrations.RenameField( + model_name="chatmodel", + old_name="chat_model", + new_name="name", + ), + migrations.AlterField( + model_name="agent", + name="chat_model", + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"), + ), + migrations.AlterField( + model_name="serverchatsettings", + name="chat_advanced", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chat_advanced", + to="database.chatmodel", + ), + ), + migrations.AlterField( + model_name="serverchatsettings", + name="chat_default", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chat_default", + to="database.chatmodel", + ), + ), + migrations.AlterField( + model_name="userconversationconfig", + name="setting", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.chatmodel", + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 2b1f5022..12a303d3 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel): return self.name -class ChatModelOptions(DbBaseModel): +class ChatModel(DbBaseModel): class ModelType(models.TextChoices): OPENAI = "openai" OFFLINE = "offline" @@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel): max_prompt_size = models.IntegerField(default=None, null=True, blank=True) subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") + name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) vision_enabled = models.BooleanField(default=False) ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True) def __str__(self): - return self.chat_model + return self.name class VoiceModelOption(DbBaseModel): @@ -297,7 +297,7 @@ class Agent(DbBaseModel): models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True ) managed_by_admin = models.BooleanField(default=False) - chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) + chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE) slug = models.CharField(max_length=200, unique=True) style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE) style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB) @@ -438,10 +438,10 @@ class WebScraper(DbBaseModel): class ServerChatSettings(DbBaseModel): chat_default = models.ForeignKey( - ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) chat_advanced = models.ForeignKey( - ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced" + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced" ) web_scraper = models.ForeignKey( WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" @@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel): class UserConversationConfig(DbBaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) - setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True) + setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True) class UserVoiceModelConfig(DbBaseModel): diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py index a46664c3..316704b9 100644 --- a/src/khoj/migrations/migrate_server_pg.py +++ b/src/khoj/migrations/migrate_server_pg.py @@ -60,7 +60,7 @@ import logging from packaging import version -from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig +from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig from khoj.utils.yaml import load_config_from_file, save_config_to_file logger = logging.getLogger(__name__) @@ -98,11 +98,11 @@ def migrate_server_pg(args): if "offline-chat" in raw_config["processor"]["conversation"]: offline_chat = raw_config["processor"]["conversation"]["offline-chat"] - ChatModelOptions.objects.create( - chat_model=offline_chat.get("chat-model"), + ChatModel.objects.create( + name=offline_chat.get("chat-model"), tokenizer=processor_conversation.get("tokenizer"), max_prompt_size=processor_conversation.get("max-prompt-size"), - model_type=ChatModelOptions.ModelType.OFFLINE, + model_type=ChatModel.ModelType.OFFLINE, ) if ( @@ -119,11 +119,11 @@ def migrate_server_pg(args): openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default") - ChatModelOptions.objects.create( - chat_model=openai.get("chat-model"), + ChatModel.objects.create( + name=openai.get("chat-model"), tokenizer=processor_conversation.get("tokenizer"), max_prompt_size=processor_conversation.get("max-prompt-size"), - model_type=ChatModelOptions.ModelType.OPENAI, + model_type=ChatModel.ModelType.OPENAI, ai_model_api=openai_model_api, ) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index fa5ff9d8..4b1a2bd8 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage -from khoj.database.models import Agent, ChatModelOptions, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -85,7 +85,7 @@ def extract_questions_anthropic( prompt = construct_structured_message( message=prompt, images=query_images, - model_type=ChatModelOptions.ModelType.ANTHROPIC, + model_type=ChatModel.ModelType.ANTHROPIC, vision_enabled=vision_enabled, attached_file_context=query_files, ) @@ -218,7 +218,7 @@ def converse_anthropic( tokenizer_name=tokenizer_name, query_images=query_images, vision_enabled=vision_available, - model_type=ChatModelOptions.ModelType.ANTHROPIC, + model_type=ChatModel.ModelType.ANTHROPIC, query_files=query_files, generated_files=generated_files, generated_asset_results=generated_asset_results, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 3567efed..6fd95ccd 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage -from khoj.database.models import Agent, ChatModelOptions, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( format_messages_for_gemini, @@ -86,7 +86,7 @@ def extract_questions_gemini( prompt = construct_structured_message( message=prompt, images=query_images, - model_type=ChatModelOptions.ModelType.GOOGLE, + model_type=ChatModel.ModelType.GOOGLE, vision_enabled=vision_enabled, attached_file_context=query_files, ) @@ -229,7 +229,7 @@ def converse_gemini( tokenizer_name=tokenizer_name, query_images=query_images, vision_enabled=vision_available, - model_type=ChatModelOptions.ModelType.GOOGLE, + model_type=ChatModel.ModelType.GOOGLE, query_files=query_files, generated_files=generated_files, generated_asset_results=generated_asset_results, diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 2091d0a9..5ce45bac 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -9,7 +9,7 @@ import pyjson5 from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent, ChatModelOptions, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -96,7 +96,7 @@ def extract_questions_offline( model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size, - model_type=ChatModelOptions.ModelType.OFFLINE, + model_type=ChatModel.ModelType.OFFLINE, query_files=query_files, ) @@ -105,7 +105,7 @@ def extract_questions_offline( response = send_message_to_model_offline( messages, loaded_model=offline_chat_model, - model=model, + model_name=model, max_prompt_size=max_prompt_size, temperature=temperature, response_type="json_object", @@ -154,7 +154,7 @@ def converse_offline( online_results={}, code_results={}, conversation_log={}, - model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", loaded_model: Union[Any, None] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -174,8 +174,8 @@ def converse_offline( """ # Initialize Variables assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) - tracer["chat_model"] = model + offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) + tracer["chat_model"] = model_name current_date = datetime.now() if agent and agent.personality: @@ -228,18 +228,18 @@ def converse_offline( system_prompt, conversation_log, context_message=context_message, - model_name=model, + model_name=model_name, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, - model_type=ChatModelOptions.ModelType.OFFLINE, + model_type=ChatModel.ModelType.OFFLINE, query_files=query_files, generated_files=generated_files, generated_asset_results=generated_asset_results, program_execution_context=additional_context, ) - logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}") + logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}") g = ThreadedGenerator(references, online_results, completion_func=completion_func) t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer)) @@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int def send_message_to_model_offline( messages: List[ChatMessage], loaded_model=None, - model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", temperature: float = 0.2, streaming=False, stop=[], @@ -282,7 +282,7 @@ def send_message_to_model_offline( tracer: dict = {}, ): assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" - offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) + offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size) messages_dict = [{"role": message.role, "content": message.content} for message in messages] seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None response = offline_chat_model.create_chat_completion( @@ -301,7 +301,7 @@ def send_message_to_model_offline( # Save conversation trace for non-streaming responses # Streamed responses need to be saved by the calling function - tracer["chat_model"] = model + tracer["chat_model"] = model_name tracer["temperature"] = temperature if is_promptrace_enabled(): commit_conversation_trace(messages, response_text, tracer) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 83e6d0df..389f52ab 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional import pyjson5 from langchain.schema import ChatMessage -from khoj.database.models import Agent, ChatModelOptions, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -83,7 +83,7 @@ def extract_questions( prompt = construct_structured_message( message=prompt, images=query_images, - model_type=ChatModelOptions.ModelType.OPENAI, + model_type=ChatModel.ModelType.OPENAI, vision_enabled=vision_enabled, attached_file_context=query_files, ) @@ -128,7 +128,7 @@ def send_message_to_model( # Get Response from GPT return completion_with_backoff( messages=messages, - model=model, + model_name=model, openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, @@ -220,7 +220,7 @@ def converse_openai( tokenizer_name=tokenizer_name, query_images=query_images, vision_enabled=vision_available, - model_type=ChatModelOptions.ModelType.OPENAI, + model_type=ChatModel.ModelType.OPENAI, query_files=query_files, generated_files=generated_files, generated_asset_results=generated_asset_results, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 160af77c..82f68259 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {} reraise=True, ) def completion_with_backoff( - messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {} + messages, + model_name: str, + temperature=0, + openai_api_key=None, + api_base_url=None, + model_kwargs: dict = {}, + tracer: dict = {}, ) -> str: client_key = f"{openai_api_key}--{api_base_url}" client: openai.OpenAI | None = openai_clients.get(client_key) @@ -56,7 +62,7 @@ def completion_with_backoff( # Update request parameters for compatability with o1 model series # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations - if model.startswith("o1"): + if model_name.startswith("o1"): temperature = 1 model_kwargs.pop("stop", None) model_kwargs.pop("response_format", None) @@ -66,12 +72,12 @@ def completion_with_backoff( chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( messages=formatted_messages, # type: ignore - model=model, # type: ignore + model=model_name, # type: ignore stream=stream, stream_options={"include_usage": True} if stream else {}, temperature=temperature, timeout=20, - **(model_kwargs or dict()), + **model_kwargs, ) aggregated_response = "" @@ -91,10 +97,11 @@ def completion_with_backoff( # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage")) + cost = chunk.usage.model_extra.get("estimated_cost") or 0 # Estimated costs returned by DeepInfra API + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost) # Save conversation trace - tracer["chat_model"] = model + tracer["chat_model"] = model_name tracer["temperature"] = temperature if is_promptrace_enabled(): commit_conversation_trace(messages, aggregated_response, tracer) @@ -139,11 +146,11 @@ def chat_completion_with_backoff( def llm_thread( g, messages, - model_name, + model_name: str, temperature, openai_api_key=None, api_base_url=None, - model_kwargs=None, + model_kwargs: dict = {}, tracer: dict = {}, ): try: @@ -177,7 +184,7 @@ def llm_thread( stream_options={"include_usage": True} if stream else {}, temperature=temperature, timeout=20, - **(model_kwargs or dict()), + **model_kwargs, ) aggregated_response = "" @@ -202,7 +209,8 @@ def llm_thread( # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 - tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + cost = chunk.usage.model_extra.get("estimated_cost") or 0 # Estimated costs returned by DeepInfra API + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost) # Save conversation trace tracer["chat_model"] = model_name diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 8de9cccf..64a46efc 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -24,7 +24,7 @@ from llama_cpp.llama import Llama from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters -from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser +from khoj.database.models import ChatModel, ClientApplication, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.search_filter.base_filter import BaseFilter @@ -330,9 +330,9 @@ def construct_structured_message( Format messages into appropriate multimedia format for supported chat model types """ if model_type in [ - ChatModelOptions.ModelType.OPENAI, - ChatModelOptions.ModelType.GOOGLE, - ChatModelOptions.ModelType.ANTHROPIC, + ChatModel.ModelType.OPENAI, + ChatModel.ModelType.GOOGLE, + ChatModel.ModelType.ANTHROPIC, ]: if not attached_file_context and not (vision_enabled and images): return message diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 11c72eb9..3f58ca1a 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -28,12 +28,7 @@ from khoj.database.adapters import ( get_default_search_model, get_user_photo, ) -from khoj.database.models import ( - Agent, - ChatModelOptions, - KhojUser, - SpeechToTextModelOptions, -) +from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( extract_questions_anthropic, @@ -404,15 +399,15 @@ async def extract_references_and_questions( # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - conversation_config = await ConversationAdapters.aget_default_conversation_config(user) - vision_enabled = conversation_config.vision_enabled + chat_model = await ConversationAdapters.aget_default_chat_model(user) + vision_enabled = chat_model.vision_enabled - if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: + if chat_model.model_type == ChatModel.ModelType.OFFLINE: using_offline_chat = True - chat_model = conversation_config.chat_model - max_tokens = conversation_config.max_prompt_size + chat_model_name = chat_model.name + max_tokens = chat_model.max_prompt_size if state.offline_chat_processor_config is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model @@ -424,18 +419,18 @@ async def extract_references_and_questions( should_extract_questions=True, location_data=location_data, user=user, - max_prompt_size=conversation_config.max_prompt_size, + max_prompt_size=chat_model.max_prompt_size, personality_context=personality_context, query_files=query_files, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = conversation_config.ai_model_api.api_key - base_url = conversation_config.ai_model_api.api_base_url - chat_model = conversation_config.chat_model + elif chat_model.model_type == ChatModel.ModelType.OPENAI: + api_key = chat_model.ai_model_api.api_key + base_url = chat_model.ai_model_api.api_base_url + chat_model_name = chat_model.name inferred_queries = extract_questions( defiltered_query, - model=chat_model, + model=chat_model_name, api_key=api_key, api_base_url=base_url, conversation_log=meta_log, @@ -447,13 +442,13 @@ async def extract_references_and_questions( query_files=query_files, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: - api_key = conversation_config.ai_model_api.api_key - chat_model = conversation_config.chat_model + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key + chat_model_name = chat_model.name inferred_queries = extract_questions_anthropic( defiltered_query, query_images=query_images, - model=chat_model, + model=chat_model_name, api_key=api_key, conversation_log=meta_log, location_data=location_data, @@ -463,17 +458,17 @@ async def extract_references_and_questions( query_files=query_files, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: - api_key = conversation_config.ai_model_api.api_key - chat_model = conversation_config.chat_model + elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key + chat_model_name = chat_model.name inferred_queries = extract_questions_gemini( defiltered_query, query_images=query_images, - model=chat_model, + model=chat_model_name, api_key=api_key, conversation_log=meta_log, location_data=location_data, - max_tokens=conversation_config.max_prompt_size, + max_tokens=chat_model.max_prompt_size, user=user, vision_enabled=vision_enabled, personality_context=personality_context, diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index 1ab35e0a..e14a666f 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -62,7 +62,7 @@ async def all_agents( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.chat_model, + "chat_model": agent.chat_model.name, "files": file_names, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -150,7 +150,7 @@ async def get_agent( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.chat_model, + "chat_model": agent.chat_model.name, "files": file_names, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -225,7 +225,7 @@ async def create_agent( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.chat_model, + "chat_model": agent.chat_model.name, "files": body.files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -286,7 +286,7 @@ async def update_agent( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.chat_model, + "chat_model": agent.chat_model.name, "files": body.files, "input_tools": agent.input_tools, "output_modes": agent.output_modes, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index eea291d3..e170818a 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -58,7 +58,7 @@ from khoj.routers.helpers import ( is_ready_to_chat, read_chat_stream, update_telemetry_state, - validate_conversation_config, + validate_chat_model, ) from khoj.routers.research import ( InformationCollectionIteration, @@ -205,7 +205,7 @@ def chat_history( n: Optional[int] = None, ): user = request.user.object - validate_conversation_config(user) + validate_chat_model(user) # Load Conversation History conversation = ConversationAdapters.get_conversation_by_user( @@ -898,10 +898,10 @@ async def chat( custom_filters = [] if conversation_commands == [ConversationCommand.Help]: if not q: - conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config(user) - model_type = conversation_config.model_type + chat_model = await ConversationAdapters.aget_user_chat_model(user) + if chat_model == None: + chat_model = await ConversationAdapters.aget_default_chat_model(user) + model_type = chat_model.model_type formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) async for result in send_llm_response(formatted_help, tracer.get("usage")): yield result diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index 6d6b9e21..88cb72ec 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -24,7 +24,7 @@ def get_chat_model_options( all_conversation_options = list() for conversation_option in conversation_options: - all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id}) + all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id}) return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200) @@ -37,12 +37,12 @@ def get_user_chat_model( ): user = request.user.object - chat_model = ConversationAdapters.get_conversation_config(user) + chat_model = ConversationAdapters.get_chat_model(user) if chat_model is None: - chat_model = ConversationAdapters.get_default_conversation_config(user) + chat_model = ConversationAdapters.get_default_chat_model(user) - return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model})) + return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name})) @api_model.post("/chat", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 36a8d008..ecd1f1e4 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -56,7 +56,7 @@ from khoj.database.adapters import ( ) from khoj.database.models import ( Agent, - ChatModelOptions, + ChatModel, ClientApplication, Conversation, GithubConfig, @@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool: return is_none_or_empty(query.strip()) -def validate_conversation_config(user: KhojUser): - default_config = ConversationAdapters.get_default_conversation_config(user) +def validate_chat_model(user: KhojUser): + default_chat_model = ConversationAdapters.get_default_chat_model(user) - if default_config is None: + if default_chat_model is None: raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.") - if default_config.model_type == "openai" and not default_config.ai_model_api: + if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api: raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.") async def is_ready_to_chat(user: KhojUser): - user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user) - if user_conversation_config == None: - user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user) + user_chat_model = await ConversationAdapters.aget_user_chat_model(user) + if user_chat_model == None: + user_chat_model = await ConversationAdapters.aget_default_chat_model(user) - if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: - chat_model = user_conversation_config.chat_model - max_tokens = user_conversation_config.max_prompt_size + if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE: + chat_model_name = user_chat_model.name + max_tokens = user_chat_model.max_prompt_size if state.offline_chat_processor_config is None: logger.info("Loading Offline Chat Model...") - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) return True if ( - user_conversation_config + user_chat_model and ( - user_conversation_config.model_type + user_chat_model.model_type in [ - ChatModelOptions.ModelType.OPENAI, - ChatModelOptions.ModelType.ANTHROPIC, - ChatModelOptions.ModelType.GOOGLE, + ChatModel.ModelType.OPENAI, + ChatModel.ModelType.ANTHROPIC, + ChatModel.ModelType.GOOGLE, ] ) - and user_conversation_config.ai_model_api + and user_chat_model.ai_model_api ): return True @@ -942,120 +942,124 @@ async def send_message_to_model_wrapper( query_files: str = None, tracer: dict = {}, ): - conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) - vision_available = conversation_config.vision_enabled + chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user) + vision_available = chat_model.vision_enabled if not vision_available and query_images: - logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.") + logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: - conversation_config = vision_enabled_config + chat_model = vision_enabled_config vision_available = True if vision_available and query_images: - logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.") + logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.") subscribed = await ais_user_subscribed(user) - chat_model = conversation_config.chat_model + chat_model_name = chat_model.name max_tokens = ( - conversation_config.subscribed_max_prompt_size - if subscribed and conversation_config.subscribed_max_prompt_size - else conversation_config.max_prompt_size + chat_model.subscribed_max_prompt_size + if subscribed and chat_model.subscribed_max_prompt_size + else chat_model.max_prompt_size ) - tokenizer = conversation_config.tokenizer - model_type = conversation_config.model_type - vision_available = conversation_config.vision_enabled + tokenizer = chat_model.tokenizer + model_type = chat_model.model_type + vision_available = chat_model.vision_enabled - if model_type == ChatModelOptions.ModelType.OFFLINE: + if model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, loaded_model=loaded_model, tokenizer_name=tokenizer, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return send_message_to_model_offline( messages=truncated_messages, loaded_model=loaded_model, - model=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, streaming=False, response_type=response_type, tracer=tracer, ) - elif model_type == ChatModelOptions.ModelType.OPENAI: - openai_chat_config = conversation_config.ai_model_api + elif model_type == ChatModel.ModelType.OPENAI: + openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key api_base_url = openai_chat_config.api_base_url truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, query_images=query_images, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return send_message_to_model( messages=truncated_messages, api_key=api_key, - model=chat_model, + model=chat_model_name, response_type=response_type, api_base_url=api_base_url, tracer=tracer, ) - elif model_type == ChatModelOptions.ModelType.ANTHROPIC: - api_key = conversation_config.ai_model_api.api_key + elif model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, query_images=query_images, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, - model=chat_model, + model=chat_model_name, response_type=response_type, tracer=tracer, ) - elif model_type == ChatModelOptions.ModelType.GOOGLE: - api_key = conversation_config.ai_model_api.api_key + elif model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key truncated_messages = generate_chatml_messages_with_context( user_message=query, context_message=context, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, query_images=query_images, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return gemini_send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer + messages=truncated_messages, + api_key=api_key, + model=chat_model_name, + response_type=response_type, + tracer=tracer, ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync( query_files: str = "", tracer: dict = {}, ): - conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) + chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user) - if conversation_config is None: + if chat_model is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") - chat_model = conversation_config.chat_model - max_tokens = conversation_config.max_prompt_size - vision_available = conversation_config.vision_enabled + chat_model_name = chat_model.name + max_tokens = chat_model.max_prompt_size + vision_available = chat_model.vision_enabled - if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: + if chat_model.model_type == ChatModel.ModelType.OFFLINE: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: - state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) + state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens) loaded_model = state.offline_chat_processor_config.loaded_model truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, loaded_model=loaded_model, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return send_message_to_model_offline( messages=truncated_messages, loaded_model=loaded_model, - model=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, streaming=False, response_type=response_type, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = conversation_config.ai_model_api.api_key + elif chat_model.model_type == ChatModel.ModelType.OPENAI: + api_key = chat_model.ai_model_api.api_key truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) openai_response = send_message_to_model( messages=truncated_messages, api_key=api_key, - model=chat_model, + model=chat_model_name, response_type=response_type, tracer=tracer, ) return openai_response - elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: - api_key = conversation_config.ai_model_api.api_key + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return anthropic_send_message_to_model( messages=truncated_messages, api_key=api_key, - model=chat_model, + model=chat_model_name, response_type=response_type, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: - api_key = conversation_config.ai_model_api.api_key + elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key truncated_messages = generate_chatml_messages_with_context( user_message=message, system_message=system_message, - model_name=chat_model, + model_name=chat_model_name, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=conversation_config.model_type, + model_type=chat_model.model_type, query_files=query_files, ) return gemini_send_message_to_model( messages=truncated_messages, api_key=api_key, - model=chat_model, + model=chat_model_name, response_type=response_type, tracer=tracer, ) @@ -1229,15 +1233,15 @@ def generate_chat_response( online_results = {} code_results = {} - conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) - vision_available = conversation_config.vision_enabled + chat_model = ConversationAdapters.get_valid_chat_model(user, conversation) + vision_available = chat_model.vision_enabled if not vision_available and query_images: vision_enabled_config = ConversationAdapters.get_vision_enabled_config() if vision_enabled_config: - conversation_config = vision_enabled_config + chat_model = vision_enabled_config vision_available = True - if conversation_config.model_type == "offline": + if chat_model.model_type == "offline": loaded_model = state.offline_chat_processor_config.loaded_model chat_response = converse_offline( user_query=query_to_run, @@ -1247,9 +1251,9 @@ def generate_chat_response( conversation_log=meta_log, completion_func=partial_completion, conversation_commands=conversation_commands, - model=conversation_config.chat_model, - max_prompt_size=conversation_config.max_prompt_size, - tokenizer_name=conversation_config.tokenizer, + model_name=chat_model.name, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, location_data=location_data, user_name=user_name, agent=agent, @@ -1259,10 +1263,10 @@ def generate_chat_response( tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: - openai_chat_config = conversation_config.ai_model_api + elif chat_model.model_type == ChatModel.ModelType.OPENAI: + openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key - chat_model = conversation_config.chat_model + chat_model_name = chat_model.name chat_response = converse_openai( compiled_references, query_to_run, @@ -1270,13 +1274,13 @@ def generate_chat_response( online_results=online_results, code_results=code_results, conversation_log=meta_log, - model=chat_model, + model=chat_model_name, api_key=api_key, api_base_url=openai_chat_config.api_base_url, completion_func=partial_completion, conversation_commands=conversation_commands, - max_prompt_size=conversation_config.max_prompt_size, - tokenizer_name=conversation_config.tokenizer, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, location_data=location_data, user_name=user_name, agent=agent, @@ -1288,8 +1292,8 @@ def generate_chat_response( tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: - api_key = conversation_config.ai_model_api.api_key + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + api_key = chat_model.ai_model_api.api_key chat_response = converse_anthropic( compiled_references, query_to_run, @@ -1297,12 +1301,12 @@ def generate_chat_response( online_results=online_results, code_results=code_results, conversation_log=meta_log, - model=conversation_config.chat_model, + model=chat_model.name, api_key=api_key, completion_func=partial_completion, conversation_commands=conversation_commands, - max_prompt_size=conversation_config.max_prompt_size, - tokenizer_name=conversation_config.tokenizer, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, location_data=location_data, user_name=user_name, agent=agent, @@ -1313,20 +1317,20 @@ def generate_chat_response( program_execution_context=program_execution_context, tracer=tracer, ) - elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: - api_key = conversation_config.ai_model_api.api_key + elif chat_model.model_type == ChatModel.ModelType.GOOGLE: + api_key = chat_model.ai_model_api.api_key chat_response = converse_gemini( compiled_references, query_to_run, online_results, code_results, meta_log, - model=conversation_config.chat_model, + model=chat_model.name, api_key=api_key, completion_func=partial_completion, conversation_commands=conversation_commands, - max_prompt_size=conversation_config.max_prompt_size, - tokenizer_name=conversation_config.tokenizer, + max_prompt_size=chat_model.max_prompt_size, + tokenizer_name=chat_model.tokenizer, location_data=location_data, user_name=user_name, agent=agent, @@ -1339,7 +1343,7 @@ def generate_chat_response( tracer=tracer, ) - metadata.update({"chat_model": conversation_config.chat_model}) + metadata.update({"chat_model": chat_model.name}) except Exception as e: logger.error(e, exc_info=True) @@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) current_notion_config = get_user_notion_config(user) notion_token = current_notion_config.token if current_notion_config else "" - selected_chat_model_config = ConversationAdapters.get_conversation_config( + selected_chat_model_config = ConversationAdapters.get_chat_model( user - ) or ConversationAdapters.get_default_conversation_config(user) + ) or ConversationAdapters.get_default_chat_model(user) chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_model_options = list() for chat_model in chat_models: - chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id}) + chat_model_options.append({"name": chat_model.name, "id": chat_model.id}) selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 6fafa6d9..6214e5f5 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -584,13 +584,15 @@ def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_toke return input_cost + output_cost + prev_cost -def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}): +def get_chat_usage_metrics( + model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}, cost: float = None +): """ - Get usage metrics for chat message based on input and output tokens + Get usage metrics for chat message based on input and output tokens and cost """ prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0} return { "input_tokens": prev_usage["input_tokens"] + input_tokens, "output_tokens": prev_usage["output_tokens"] + output_tokens, - "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]), + "cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]), } diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 91c6ce1a..a4864dcc 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -7,7 +7,7 @@ import openai from khoj.database.adapters import ConversationAdapters from khoj.database.models import ( AiModelApi, - ChatModelOptions, + ChatModel, KhojUser, SpeechToTextModelOptions, TextToImageModelConfig, @@ -63,7 +63,7 @@ def initialization(interactive: bool = True): # Set up OpenAI's online chat models openai_configured, openai_provider = _setup_chat_model_provider( - ChatModelOptions.ModelType.OPENAI, + ChatModel.ModelType.OPENAI, default_chat_models, default_api_key=openai_api_key, api_base_url=openai_api_base, @@ -105,7 +105,7 @@ def initialization(interactive: bool = True): # Set up Google's Gemini online chat models _setup_chat_model_provider( - ChatModelOptions.ModelType.GOOGLE, + ChatModel.ModelType.GOOGLE, default_gemini_chat_models, default_api_key=os.getenv("GEMINI_API_KEY"), vision_enabled=True, @@ -116,7 +116,7 @@ def initialization(interactive: bool = True): # Set up Anthropic's online chat models _setup_chat_model_provider( - ChatModelOptions.ModelType.ANTHROPIC, + ChatModel.ModelType.ANTHROPIC, default_anthropic_chat_models, default_api_key=os.getenv("ANTHROPIC_API_KEY"), vision_enabled=True, @@ -126,7 +126,7 @@ def initialization(interactive: bool = True): # Set up offline chat models _setup_chat_model_provider( - ChatModelOptions.ModelType.OFFLINE, + ChatModel.ModelType.OFFLINE, default_offline_chat_models, default_api_key=None, vision_enabled=False, @@ -135,9 +135,9 @@ def initialization(interactive: bool = True): ) # Explicitly set default chat model - chat_models_configured = ChatModelOptions.objects.count() + chat_models_configured = ChatModel.objects.count() if chat_models_configured > 0: - default_chat_model_name = ChatModelOptions.objects.first().chat_model + default_chat_model_name = ChatModel.objects.first().name # If there are multiple chat models, ask the user to choose the default chat model if chat_models_configured > 1 and interactive: user_chat_model_name = input( @@ -147,7 +147,7 @@ def initialization(interactive: bool = True): user_chat_model_name = None # If the user's choice is valid, set it as the default chat model - if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists(): + if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists(): default_chat_model_name = user_chat_model_name logger.info("🗣️ Chat model configuration complete") @@ -171,7 +171,7 @@ def initialization(interactive: bool = True): logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}") def _setup_chat_model_provider( - model_type: ChatModelOptions.ModelType, + model_type: ChatModel.ModelType, default_chat_models: list, default_api_key: str, interactive: bool, @@ -204,10 +204,10 @@ def initialization(interactive: bool = True): ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url) if interactive: - chat_model_names = input( + user_chat_models = input( f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): " ) - chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models + chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models chat_models = [model.strip() for model in chat_models] else: chat_models = default_chat_models @@ -218,7 +218,7 @@ def initialization(interactive: bool = True): vision_enabled = vision_enabled and chat_model in supported_vision_models chat_model_options = { - "chat_model": chat_model, + "name": chat_model, "model_type": model_type, "max_prompt_size": default_max_tokens, "vision_enabled": vision_enabled, @@ -226,7 +226,7 @@ def initialization(interactive: bool = True): "ai_model_api": ai_model_api, } - ChatModelOptions.objects.create(**chat_model_options) + ChatModel.objects.create(**chat_model_options) logger.info(f"🗣️ {provider_name} chat model configuration complete") return True, ai_model_api @@ -250,19 +250,19 @@ def initialization(interactive: bool = True): available_models = [model.id for model in openai_client.models.list()] # Get existing chat model options for this config - existing_models = ChatModelOptions.objects.filter( - ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI + existing_models = ChatModel.objects.filter( + ai_model_api=config, model_type=ChatModel.ModelType.OPENAI ) # Add new models - for model in available_models: - if not existing_models.filter(chat_model=model).exists(): - ChatModelOptions.objects.create( - chat_model=model, - model_type=ChatModelOptions.ModelType.OPENAI, - max_prompt_size=model_to_prompt_size.get(model), - vision_enabled=model in default_openai_chat_models, - tokenizer=model_to_tokenizer.get(model), + for model_name in available_models: + if not existing_models.filter(name=model_name).exists(): + ChatModel.objects.create( + name=model_name, + model_type=ChatModel.ModelType.OPENAI, + max_prompt_size=model_to_prompt_size.get(model_name), + vision_enabled=model_name in default_openai_chat_models, + tokenizer=model_to_tokenizer.get(model_name), ai_model_api=config, ) @@ -284,7 +284,7 @@ def initialization(interactive: bool = True): except Exception as e: logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True) - chat_config = ConversationAdapters.get_default_conversation_config() + chat_config = ConversationAdapters.get_default_chat_model() if admin_user is None and chat_config is None: while True: try: diff --git a/tests/conftest.py b/tests/conftest.py index 7561d901..1795b340 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from khoj.configure import ( ) from khoj.database.models import ( Agent, - ChatModelOptions, + ChatModel, GithubConfig, GithubRepoConfig, KhojApiUser, @@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig from tests.helpers import ( AiModelApiFactory, - ChatModelOptionsFactory, + ChatModelFactory, ProcessLockFactory, SubscriptionFactory, UserConversationProcessorConfigFactory, @@ -184,14 +184,14 @@ def api_user4(default_user4): @pytest.mark.django_db @pytest.fixture def default_openai_chat_model_option(): - chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai") + chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai") return chat_model @pytest.mark.django_db @pytest.fixture def offline_agent(): - chat_model = ChatModelOptionsFactory() + chat_model = ChatModelFactory() return Agent.objects.create( name="Accountant", chat_model=chat_model, @@ -202,7 +202,7 @@ def offline_agent(): @pytest.mark.django_db @pytest.fixture def openai_agent(): - chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai") + chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai") return Agent.objects.create( name="Accountant", chat_model=chat_model, @@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa # Initialize Processor from Config chat_provider = get_chat_provider() - online_chat_model: ChatModelOptionsFactory = None - if chat_provider == ChatModelOptions.ModelType.OPENAI: - online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai") - elif chat_provider == ChatModelOptions.ModelType.GOOGLE: - online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google") - elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC: - online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic") + online_chat_model: ChatModelFactory = None + if chat_provider == ChatModel.ModelType.OPENAI: + online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai") + elif chat_provider == ChatModel.ModelType.GOOGLE: + online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google") + elif chat_provider == ChatModel.ModelType.ANTHROPIC: + online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic") if online_chat_model: online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider)) UserConversationProcessorConfigFactory(user=user, setting=online_chat_model) @@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): configure_content(default_user2, all_files) # Initialize Processor from Config - ChatModelOptionsFactory( - chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF", + ChatModelFactory( + name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF", tokenizer=None, max_prompt_size=None, model_type="offline", diff --git a/tests/helpers.py b/tests/helpers.py index 04ed6df5..b2c6a3b1 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from django.utils.timezone import make_aware from khoj.database.models import ( AiModelApi, - ChatModelOptions, + ChatModel, Conversation, KhojApiUser, KhojUser, @@ -18,27 +18,27 @@ from khoj.database.models import ( from khoj.processor.conversation.utils import message_to_log -def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE): +def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE): provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER") - if provider and provider in ChatModelOptions.ModelType: - return ChatModelOptions.ModelType(provider) + if provider and provider in ChatModel.ModelType: + return ChatModel.ModelType(provider) elif os.getenv("OPENAI_API_KEY"): - return ChatModelOptions.ModelType.OPENAI + return ChatModel.ModelType.OPENAI elif os.getenv("GEMINI_API_KEY"): - return ChatModelOptions.ModelType.GOOGLE + return ChatModel.ModelType.GOOGLE elif os.getenv("ANTHROPIC_API_KEY"): - return ChatModelOptions.ModelType.ANTHROPIC + return ChatModel.ModelType.ANTHROPIC else: return default -def get_chat_api_key(provider: ChatModelOptions.ModelType = None): +def get_chat_api_key(provider: ChatModel.ModelType = None): provider = provider or get_chat_provider() - if provider == ChatModelOptions.ModelType.OPENAI: + if provider == ChatModel.ModelType.OPENAI: return os.getenv("OPENAI_API_KEY") - elif provider == ChatModelOptions.ModelType.GOOGLE: + elif provider == ChatModel.ModelType.GOOGLE: return os.getenv("GEMINI_API_KEY") - elif provider == ChatModelOptions.ModelType.ANTHROPIC: + elif provider == ChatModel.ModelType.ANTHROPIC: return os.getenv("ANTHROPIC_API_KEY") else: return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY") @@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory): api_key = get_chat_api_key() -class ChatModelOptionsFactory(factory.django.DjangoModelFactory): +class ChatModelFactory(factory.django.DjangoModelFactory): class Meta: - model = ChatModelOptions + model = ChatModel max_prompt_size = 20000 tokenizer = None - chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF" + name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF" model_type = get_chat_provider() ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None) @@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory): model = UserConversationConfig user = factory.SubFactory(UserFactory) - setting = factory.SubFactory(ChatModelOptionsFactory) + setting = factory.SubFactory(ChatModelFactory) class ConversationFactory(factory.django.DjangoModelFactory): diff --git a/tests/test_agents.py b/tests/test_agents.py index da0b2357..242495e6 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -5,14 +5,14 @@ import pytest from asgiref.sync import sync_to_async from khoj.database.adapters import AgentAdapters -from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser +from khoj.database.models import Agent, ChatModel, Entry, KhojUser from khoj.routers.api import execute_search from khoj.utils.helpers import get_absolute_path -from tests.helpers import ChatModelOptionsFactory +from tests.helpers import ChatModelFactory def test_create_default_agent(default_user: KhojUser): - ChatModelOptionsFactory() + ChatModelFactory() agent = AgentAdapters.create_default_agent(default_user) assert agent is not None @@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser): @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions): +async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel): new_agent = await AgentAdapters.aupdate_agent( default_user, "Test Agent", @@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha Agent.PrivacyLevel.PRIVATE, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [], [], [], @@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_create_or_update_agent_with_knowledge_base( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base( Agent.PrivacyLevel.PRIVATE, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base( @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_create_or_update_agent_with_knowledge_base_and_search( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search( Agent.PrivacyLevel.PRIVATE, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search( @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_agent_with_knowledge_base_and_search_not_creator( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator( Agent.PrivacyLevel.PUBLIC, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator( @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_agent_with_knowledge_base_and_search_not_creator_and_private( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private( Agent.PrivacyLevel.PRIVATE, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private( @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce Agent.PrivacyLevel.PRIVATE, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce @pytest.mark.anyio @pytest.mark.django_db(transaction=True) async def test_multiple_agents_with_knowledge_base_and_users( - default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser + default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser ): full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown") new_agent = await AgentAdapters.aupdate_agent( @@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users( Agent.PrivacyLevel.PUBLIC, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename], [], [], @@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users( Agent.PrivacyLevel.PUBLIC, "icon", "color", - default_openai_chat_model_option.chat_model, + default_openai_chat_model_option.name, [full_filename2], [], [], diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index e84612a2..6f18f658 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -2,12 +2,12 @@ from datetime import datetime import pytest -from khoj.database.models import ChatModelOptions +from khoj.database.models import ChatModel from khoj.routers.helpers import aget_data_sources_and_output_format from khoj.utils.helpers import ConversationCommand from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider -SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE +SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE pytestmark = pytest.mark.skipif( SKIP_TESTS, reason="Disable in CI to avoid long test runs.", diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index 5caa1ca0..b681d38c 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -4,12 +4,12 @@ import pytest from faker import Faker from freezegun import freeze_time -from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser +from khoj.database.models import Agent, ChatModel, Entry, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log from tests.helpers import ConversationFactory, get_chat_provider -SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE +SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE pytestmark = pytest.mark.skipif( SKIP_TESTS, reason="Disable in CI to avoid long test runs.",