Merge branch 'master' of github.com:khoj-ai/khoj into features/new-sign-in-page

sabaimran 2024-12-12 15:43:06 -08:00
commit dfc150c442
30 changed files with 412 additions and 340 deletions

View file

@ -113,7 +113,8 @@ jobs:
khoj --anonymous-mode --non-interactive &
# Start code sandbox
npm run dev --prefix terrarium &
npm install -g pm2
npm run ci --prefix terrarium
# Wait for server to be ready
timeout=120
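
The workflow starts the Khoj server and the code sandbox in the background, then polls until the server responds before running tests. Below is a minimal sketch of such a readiness wait, written in Python for illustration (the workflow itself uses shell); the health endpoint path `/api/health` is an assumption:

```python
import time
import urllib.error
import urllib.request

def wait_for_server(url: str = "http://localhost:42110/api/health", timeout: int = 120) -> bool:
    """Poll the server's health endpoint until it responds or the timeout elapses."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # Server not up yet; retry after a short pause
        time.sleep(2)
    return False

if __name__ == "__main__":
    if not wait_for_server():
        raise SystemExit("Khoj server did not become ready in time")
```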

View file

@ -25,7 +25,7 @@ Using LiteLLM with Khoj makes it possible to turn any LLM behind an API into you
- Name: `proxy-name`
- Api Key: `any string`
- Api Base Url: **URL of your Openai Proxy API**
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- Name: `llama3.1` (replace with the name of your local model)
- Model Type: `Openai`
- Openai Config: `<the proxy config you created in step 3>`
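
Before pointing Khoj at the proxy, it can help to confirm the OpenAI-compatible endpoint responds. A minimal sketch using the `openai` Python client; the base URL assumes LiteLLM's default proxy port (4000), and the model name is a placeholder for whatever your proxy serves:

```python
from openai import OpenAI

# The proxy accepts any string as the API key, matching the Khoj config above.
client = OpenAI(api_key="any string", base_url="http://localhost:4000/v1")

response = client.chat.completions.create(
    model="llama3.1",  # replace with the model name served by your proxy
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)
```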

View file

@ -18,7 +18,7 @@ LM Studio can expose an [OpenAI API compatible server](https://lmstudio.ai/docs/
- Name: `proxy-name`
- Api Key: `any string`
- Api Base Url: `http://localhost:1234/v1/` (default for LMStudio)
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- Name: `llama3.1` (replace with the name of your local model)
- Model Type: `Openai`
- Openai Config: `<the proxy config you created in step 3>`
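
To sanity-check that LM Studio's server is up before configuring Khoj, you can list the models it exposes. A minimal sketch with the `openai` Python client against LM Studio's default address:

```python
from openai import OpenAI

client = OpenAI(api_key="any string", base_url="http://localhost:1234/v1")

# LM Studio reports whichever models are currently loaded in the app.
for model in client.models.list():
    print(model.id)
```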

View file

@ -64,7 +64,7 @@ Restart your Khoj server after first run or update to the settings below to ensu
- Name: `ollama`
- Api Key: `any string`
- Api Base Url: `http://localhost:11434/v1/` (default for Ollama)
4. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
4. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- Name: `llama3.1` (replace with the name of your local model)
- Model Type: `Openai`
- Openai Config: `<the ollama config you created in step 3>`
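
You can also exercise Ollama's OpenAI-compatible endpoint directly over HTTP. A minimal sketch with `requests`; assumes Ollama is serving on its default port and that you have already pulled the `llama3.1` model:

```python
import requests

response = requests.post(
    "http://localhost:11434/v1/chat/completions",
    json={
        "model": "llama3.1",  # must match a model you've pulled with `ollama pull`
        "messages": [{"role": "user", "content": "Say hello in one word."}],
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```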

View file

@ -25,7 +25,7 @@ For specific integrations, see our [Ollama](/advanced/ollama), [LMStudio](/advan
- Name: `any name`
- Api Key: `any string`
- Api Base Url: **URL of your Openai Proxy API**
3. Create a new [Chat Model Option](http://localhost:42110/server/admin/database/chatmodeloptions/add) on your Khoj admin panel.
3. Create a new [Chat Model](http://localhost:42110/server/admin/database/chatmodel/add) on your Khoj admin panel.
- Name: `llama3` (replace with the name of your local model)
- Model Type: `Openai`
- Openai Config: `<the proxy config you created in step 2>`

View file

@ -17,7 +17,7 @@ You have a couple of image generation options.
We support most state-of-the-art image generation models, including Ideogram, Flux, and Stable Diffusion. These run using [Replicate](https://replicate.com). Here's how to set them up:
1. Get a Replicate API key [here](https://replicate.com/account/api-tokens).
1. Create a new [Text to Image Model](https://app.khoj.dev/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models).
1. Create a new [Text to Image Model](http://localhost:42110/server/admin/database/texttoimagemodelconfig/). Set the `type` to `Replicate`. Use any of the model names you see [on this list](https://replicate.com/pricing#image-models).
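
To verify your Replicate key and chosen model outside Khoj, the `replicate` Python client can run the model directly. A sketch assuming the `REPLICATE_API_TOKEN` environment variable is set, with `black-forest-labs/flux-schnell` used as an example model name from the pricing list:

```python
import replicate  # pip install replicate; reads REPLICATE_API_TOKEN from the environment

# Run a text-to-image model once and print the generated image output/URL(s).
output = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "a watercolor painting of a lighthouse at dusk"},
)
print(output)
```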
### OpenAI

View file

@ -307,7 +307,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
- Give the configuration a friendly name like `OpenAI`
- (Optional) Set the API base URL. It is only relevant if you're using another OpenAI-compatible proxy server like [Ollama](/advanced/ollama) or [LMStudio](/advanced/lmstudio).<br />
![example configuration for ai model api](/img/example_openai_processor_config.png)
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
- Set the `chat-model` field to an [OpenAI chat model](https://platform.openai.com/docs/models). Example: `gpt-4o`.
- Make sure to set the `model-type` field to `OpenAI`.
- If your model supports vision, set the `vision enabled` field to `true`. This is currently only supported for OpenAI models with vision capabilities.
@ -318,7 +318,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
- Add your [Anthropic API key](https://console.anthropic.com/account/keys)
- Give the configuration a friendly name like `Anthropic`. Do not configure the API base url.
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
- Set the `chat-model` field to an [Anthropic chat model](https://docs.anthropic.com/en/docs/about-claude/models#model-names). Example: `claude-3-5-sonnet-20240620`.
- Set the `model-type` field to `Anthropic`.
- Set the `ai model api` field to the Anthropic AI Model API you created in step 1.
@ -327,7 +327,7 @@ Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more cu
1. Create a new [AI Model API](http://localhost:42110/server/admin/database/aimodelapi/add) in the server admin settings.
- Add your [Gemini API key](https://aistudio.google.com/app/apikey)
- Give the configuration a friendly name like `Gemini`. Do not configure the API base url.
2. Create a new [chat model options](http://localhost:42110/server/admin/database/chatmodeloptions/add)
2. Create a new [chat model](http://localhost:42110/server/admin/database/chatmodel/add)
- Set the `chat-model` field to a [Google Gemini chat model](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models). Example: `gemini-1.5-flash`.
- Set the `model-type` field to `Gemini`.
- Set the `ai model api` field to the Gemini AI Model API you created in step 1.
@ -343,7 +343,7 @@ Offline chat stays completely private and can work without internet using any op
:::
1. Get the name of your preferred chat model from [HuggingFace](https://huggingface.co/models?pipeline_tag=text-generation&library=gguf). *Most GGUF format chat models are supported*.
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodeloptions/add/) on the admin panel
2. Open the [create chat model page](http://localhost:42110/server/admin/database/chatmodel/add/) on the admin panel
3. Set the `chat-model` field to the name of your preferred chat model
- Make sure the `model-type` is set to `Offline`
4. Set the newly added chat model as your preferred model in your [User chat settings](http://localhost:42110/settings) and [Server chat settings](http://localhost:42110/server/admin/database/serverchatsettings/).
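
Under the hood, offline chat runs these GGUF models via llama-cpp. A minimal sketch of loading one yourself, assuming `llama-cpp-python` (and `huggingface-hub`) is installed; the quantization filename pattern is an assumption, adjust to the files in the repo you pick:

```python
from llama_cpp import Llama

# Download and load a GGUF chat model from HuggingFace by repo id.
llm = Llama.from_pretrained(
    repo_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    filename="*Q4_K_M.gguf",  # which quantization to fetch
    n_ctx=4096,               # context window size
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}]
)
print(response["choices"][0]["message"]["content"])
```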

View file

@ -36,7 +36,7 @@ from torch import Tensor
from khoj.database.models import (
Agent,
AiModelApi,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
Entry,
@ -736,8 +736,8 @@ class AgentAdapters:
@staticmethod
def create_default_agent(user: KhojUser):
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
if default_conversation_config is None:
default_chat_model = ConversationAdapters.get_default_chat_model(user)
if default_chat_model is None:
logger.info("No default conversation config found, skipping default agent creation")
return None
default_personality = prompts.personality.format(current_date="placeholder", day_of_week="placeholder")
@ -746,7 +746,7 @@ class AgentAdapters:
if agent:
agent.personality = default_personality
agent.chat_model = default_conversation_config
agent.chat_model = default_chat_model
agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
agent.name = AgentAdapters.DEFAULT_AGENT_NAME
agent.privacy_level = Agent.PrivacyLevel.PUBLIC
@ -760,7 +760,7 @@ class AgentAdapters:
name=AgentAdapters.DEFAULT_AGENT_NAME,
privacy_level=Agent.PrivacyLevel.PUBLIC,
managed_by_admin=True,
chat_model=default_conversation_config,
chat_model=default_chat_model,
personality=default_personality,
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
)
@ -787,7 +787,7 @@ class AgentAdapters:
output_modes: List[str],
slug: Optional[str] = None,
):
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
chat_model_option = await ChatModel.objects.filter(name=chat_model).afirst()
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
@ -972,29 +972,29 @@ class ConversationAdapters:
@staticmethod
@require_valid_user
def has_any_conversation_config(user: KhojUser):
return ChatModelOptions.objects.filter(user=user).exists()
def has_any_chat_model(user: KhojUser):
return ChatModel.objects.filter(user=user).exists()
@staticmethod
def get_all_conversation_configs():
return ChatModelOptions.objects.all()
def get_all_chat_models():
return ChatModel.objects.all()
@staticmethod
async def aget_all_conversation_configs():
return await sync_to_async(list)(ChatModelOptions.objects.prefetch_related("ai_model_api").all())
async def aget_all_chat_models():
return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
@staticmethod
def get_vision_enabled_config():
conversation_configurations = ConversationAdapters.get_all_conversation_configs()
for config in conversation_configurations:
chat_models = ConversationAdapters.get_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@staticmethod
async def aget_vision_enabled_config():
conversation_configurations = await ConversationAdapters.aget_all_conversation_configs()
for config in conversation_configurations:
chat_models = await ConversationAdapters.aget_all_chat_models()
for config in chat_models:
if config.vision_enabled:
return config
return None
@ -1010,7 +1010,7 @@ class ConversationAdapters:
@staticmethod
@arequire_valid_user
async def aset_user_conversation_processor(user: KhojUser, conversation_processor_config_id: int):
config = await ChatModelOptions.objects.filter(id=conversation_processor_config_id).afirst()
config = await ChatModel.objects.filter(id=conversation_processor_config_id).afirst()
if not config:
return None
new_config = await UserConversationConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
@ -1026,24 +1026,24 @@ class ConversationAdapters:
return new_config
@staticmethod
def get_conversation_config(user: KhojUser):
def get_chat_model(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_conversation_config(user)
return ConversationAdapters.get_default_chat_model(user)
config = UserConversationConfig.objects.filter(user=user).first()
if config:
return config.setting
return ConversationAdapters.get_advanced_conversation_config(user)
return ConversationAdapters.get_advanced_chat_model(user)
@staticmethod
async def aget_conversation_config(user: KhojUser):
async def aget_chat_model(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_conversation_config(user)
return await ConversationAdapters.aget_default_chat_model(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if config:
return config.setting
return ConversationAdapters.aget_advanced_conversation_config(user)
return ConversationAdapters.aget_advanced_chat_model(user)
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@ -1064,7 +1064,7 @@ class ConversationAdapters:
return VoiceModelOption.objects.first()
@staticmethod
def get_default_conversation_config(user: KhojUser = None):
def get_default_chat_model(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first()
@ -1084,10 +1084,10 @@ class ConversationAdapters:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return ChatModelOptions.objects.filter().first()
return ChatModel.objects.filter().first()
@staticmethod
async def aget_default_conversation_config(user: KhojUser = None):
async def aget_default_chat_model(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
@ -1117,17 +1117,17 @@ class ConversationAdapters:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return await ChatModelOptions.objects.filter().prefetch_related("ai_model_api").afirst()
return await ChatModel.objects.filter().prefetch_related("ai_model_api").afirst()
@staticmethod
def get_advanced_conversation_config(user: KhojUser):
def get_advanced_chat_model(user: KhojUser):
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return ConversationAdapters.get_default_conversation_config(user)
return ConversationAdapters.get_default_chat_model(user)
@staticmethod
async def aget_advanced_conversation_config(user: KhojUser = None):
async def aget_advanced_chat_model(user: KhojUser = None):
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("chat_advanced", "chat_advanced__ai_model_api")
@ -1135,7 +1135,7 @@ class ConversationAdapters:
)
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return await ConversationAdapters.aget_default_conversation_config(user)
return await ConversationAdapters.aget_default_chat_model(user)
@staticmethod
async def aget_server_webscraper():
@ -1247,16 +1247,16 @@ class ConversationAdapters:
@staticmethod
def get_conversation_processor_options():
return ChatModelOptions.objects.all()
return ChatModel.objects.all()
@staticmethod
def set_conversation_processor_config(user: KhojUser, new_config: ChatModelOptions):
def set_user_chat_model(user: KhojUser, chat_model: ChatModel):
user_conversation_config, _ = UserConversationConfig.objects.get_or_create(user=user)
user_conversation_config.setting = new_config
user_conversation_config.setting = chat_model
user_conversation_config.save()
@staticmethod
async def aget_user_conversation_config(user: KhojUser):
async def aget_user_chat_model(user: KhojUser):
config = (
await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()
)
@ -1288,33 +1288,33 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
def get_valid_chat_model(user: KhojUser, conversation: Conversation):
agent: Agent = conversation.agent if AgentAdapters.get_default_agent() != conversation.agent else None
if agent and agent.chat_model:
conversation_config = conversation.agent.chat_model
chat_model = conversation.agent.chat_model
else:
conversation_config = ConversationAdapters.get_conversation_config(user)
chat_model = ConversationAdapters.get_chat_model(user)
if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()
if chat_model is None:
chat_model = ConversationAdapters.get_default_chat_model()
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return conversation_config
return chat_model
if (
conversation_config.model_type
chat_model.model_type
in [
ChatModelOptions.ModelType.ANTHROPIC,
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.GOOGLE,
]
) and conversation_config.ai_model_api:
return conversation_config
) and chat_model.ai_model_api:
return chat_model
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")

View file

@ -16,7 +16,7 @@ from unfold import admin as unfold_admin
from khoj.database.models import (
Agent,
AiModelApi,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
Entry,
@ -212,15 +212,15 @@ class KhojUserSubscription(unfold_admin.ModelAdmin):
list_filter = ("type",)
@admin.register(ChatModelOptions)
class ChatModelOptionsAdmin(unfold_admin.ModelAdmin):
@admin.register(ChatModel)
class ChatModelAdmin(unfold_admin.ModelAdmin):
list_display = (
"id",
"chat_model",
"name",
"ai_model_api",
"max_prompt_size",
)
search_fields = ("id", "chat_model", "ai_model_api__name")
search_fields = ("id", "name", "ai_model_api__name")
@admin.register(TextToImageModelConfig)
@ -385,7 +385,7 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
"get_chat_model",
"get_subscription_type",
)
search_fields = ("id", "user__email", "setting__chat_model", "user__subscription__type")
search_fields = ("id", "user__email", "setting__name", "user__subscription__type")
ordering = ("-updated_at",)
def get_user_email(self, obj):
@ -395,10 +395,10 @@ class UserConversationConfigAdmin(unfold_admin.ModelAdmin):
get_user_email.admin_order_field = "user__email" # type: ignore
def get_chat_model(self, obj):
return obj.setting.chat_model if obj.setting else None
return obj.setting.name if obj.setting else None
get_chat_model.short_description = "Chat Model" # type: ignore
get_chat_model.admin_order_field = "setting__chat_model" # type: ignore
get_chat_model.admin_order_field = "setting__name" # type: ignore
def get_subscription_type(self, obj):
if hasattr(obj.user, "subscription"):

View file

@ -0,0 +1,62 @@
# Generated by Django 5.0.9 on 2024-12-09 04:21
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0076_rename_openaiprocessorconversationconfig_aimodelapi_and_more"),
]
operations = [
migrations.RenameModel(
old_name="ChatModelOptions",
new_name="ChatModel",
),
migrations.RenameField(
model_name="chatmodel",
old_name="chat_model",
new_name="name",
),
migrations.AlterField(
model_name="agent",
name="chat_model",
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodel"),
),
migrations.AlterField(
model_name="serverchatsettings",
name="chat_advanced",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_advanced",
to="database.chatmodel",
),
),
migrations.AlterField(
model_name="serverchatsettings",
name="chat_default",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_default",
to="database.chatmodel",
),
),
migrations.AlterField(
model_name="userconversationconfig",
name="setting",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.chatmodel",
),
),
]

View file

@ -193,7 +193,7 @@ class AiModelApi(DbBaseModel):
return self.name
class ChatModelOptions(DbBaseModel):
class ChatModel(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@ -203,13 +203,13 @@ class ChatModelOptions(DbBaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
vision_enabled = models.BooleanField(default=False)
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
def __str__(self):
return self.chat_model
return self.name
class VoiceModelOption(DbBaseModel):
@ -297,7 +297,7 @@ class Agent(DbBaseModel):
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
chat_model = models.ForeignKey(ChatModel, on_delete=models.CASCADE)
slug = models.CharField(max_length=200, unique=True)
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
@ -438,10 +438,10 @@ class WebScraper(DbBaseModel):
class ServerChatSettings(DbBaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
chat_advanced = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
)
web_scraper = models.ForeignKey(
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
@ -563,7 +563,7 @@ class SpeechToTextModelOptions(DbBaseModel):
class UserConversationConfig(DbBaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
setting = models.ForeignKey(ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True)
class UserVoiceModelConfig(DbBaseModel):

View file

@ -60,7 +60,7 @@ import logging
from packaging import version
from khoj.database.models import AiModelApi, ChatModelOptions, SearchModelConfig
from khoj.database.models import AiModelApi, ChatModel, SearchModelConfig
from khoj.utils.yaml import load_config_from_file, save_config_to_file
logger = logging.getLogger(__name__)
@ -98,11 +98,11 @@ def migrate_server_pg(args):
if "offline-chat" in raw_config["processor"]["conversation"]:
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
ChatModelOptions.objects.create(
chat_model=offline_chat.get("chat-model"),
ChatModel.objects.create(
name=offline_chat.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
)
if (
@ -119,11 +119,11 @@ def migrate_server_pg(args):
openai_model_api = AiModelApi.objects.create(api_key=openai.get("api-key"), name="default")
ChatModelOptions.objects.create(
chat_model=openai.get("chat-model"),
ChatModel.objects.create(
name=openai.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"),
max_prompt_size=processor_conversation.get("max-prompt-size"),
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
ai_model_api=openai_model_api,
)

View file

@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
@ -85,7 +85,7 @@ def extract_questions_anthropic(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
model_type=ChatModel.ModelType.ANTHROPIC,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@ -218,7 +218,7 @@ def converse_anthropic(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.ANTHROPIC,
model_type=ChatModel.ModelType.ANTHROPIC,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View file

@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
@ -86,7 +86,7 @@ def extract_questions_gemini(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
model_type=ChatModel.ModelType.GOOGLE,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@ -229,7 +229,7 @@ def converse_gemini(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
model_type=ChatModel.ModelType.GOOGLE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View file

@ -9,7 +9,7 @@ import pyjson5
from langchain.schema import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
@ -96,7 +96,7 @@ def extract_questions_offline(
model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
query_files=query_files,
)
@ -105,7 +105,7 @@ def extract_questions_offline(
response = send_message_to_model_offline(
messages,
loaded_model=offline_chat_model,
model=model,
model_name=model,
max_prompt_size=max_prompt_size,
temperature=temperature,
response_type="json_object",
@ -154,7 +154,7 @@ def converse_offline(
online_results={},
code_results={},
conversation_log={},
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@ -174,8 +174,8 @@ def converse_offline(
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
tracer["chat_model"] = model
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
tracer["chat_model"] = model_name
current_date = datetime.now()
if agent and agent.personality:
@ -228,18 +228,18 @@ def converse_offline(
system_prompt,
conversation_log,
context_message=context_message,
model_name=model,
model_name=model_name,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
model_type=ChatModelOptions.ModelType.OFFLINE,
model_type=ChatModel.ModelType.OFFLINE,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,
program_execution_context=additional_context,
)
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
@ -273,7 +273,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
def send_message_to_model_offline(
messages: List[ChatMessage],
loaded_model=None,
model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
temperature: float = 0.2,
streaming=False,
stop=[],
@ -282,7 +282,7 @@ def send_message_to_model_offline(
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
response = offline_chat_model.create_chat_completion(
@ -301,7 +301,7 @@ def send_message_to_model_offline(
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer)

View file

@ -5,7 +5,7 @@ from typing import Dict, List, Optional
import pyjson5
from langchain.schema import ChatMessage
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@ -83,7 +83,7 @@ def extract_questions(
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
vision_enabled=vision_enabled,
attached_file_context=query_files,
)
@ -128,7 +128,7 @@ def send_message_to_model(
# Get Response from GPT
return completion_with_backoff(
messages=messages,
model=model,
model_name=model,
openai_api_key=api_key,
temperature=temperature,
api_base_url=api_base_url,
@ -220,7 +220,7 @@ def converse_openai(
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
model_type=ChatModel.ModelType.OPENAI,
query_files=query_files,
generated_files=generated_files,
generated_asset_results=generated_asset_results,

View file

@ -40,7 +40,13 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
messages,
model_name: str,
temperature=0,
openai_api_key=None,
api_base_url=None,
model_kwargs: dict = {},
tracer: dict = {},
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@ -56,7 +62,7 @@ def completion_with_backoff(
# Update request parameters for compatibility with o1 model series
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
if model.startswith("o1"):
if model_name.startswith("o1"):
temperature = 1
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
@ -66,12 +72,12 @@ def completion_with_backoff(
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model, # type: ignore
model=model_name, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
**model_kwargs,
)
aggregated_response = ""
@ -91,10 +97,11 @@ def completion_with_backoff(
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
cost = chunk.usage.model_extra.get("estimated_cost") or 0 # Estimated costs returned by DeepInfra API
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost)
# Save conversation trace
tracer["chat_model"] = model
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
@ -139,11 +146,11 @@ def chat_completion_with_backoff(
def llm_thread(
g,
messages,
model_name,
model_name: str,
temperature,
openai_api_key=None,
api_base_url=None,
model_kwargs=None,
model_kwargs: dict = {},
tracer: dict = {},
):
try:
@ -177,7 +184,7 @@ def llm_thread(
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
**model_kwargs,
)
aggregated_response = ""
@ -202,7 +209,8 @@ def llm_thread(
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
cost = chunk.usage.model_extra.get("estimated_cost") or 0 # Estimated costs returned by DeepInfra API
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"), cost)
# Save conversation trace
tracer["chat_model"] = model_name

View file

@ -24,7 +24,7 @@ from llama_cpp.llama import Llama
from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.database.models import ChatModel, ClientApplication, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
@ -330,9 +330,9 @@ def construct_structured_message(
Format messages into appropriate multimedia format for supported chat model types
"""
if model_type in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.GOOGLE,
ChatModelOptions.ModelType.ANTHROPIC,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.GOOGLE,
ChatModel.ModelType.ANTHROPIC,
]:
if not attached_file_context and not (vision_enabled and images):
return message

View file

@ -28,12 +28,7 @@ from khoj.database.adapters import (
get_default_search_model,
get_user_photo,
)
from khoj.database.models import (
Agent,
ChatModelOptions,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
@ -404,15 +399,15 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
vision_enabled = conversation_config.vision_enabled
chat_model = await ConversationAdapters.aget_default_chat_model(user)
vision_enabled = chat_model.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
using_offline_chat = True
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
@ -424,18 +419,18 @@ async def extract_references_and_questions(
should_extract_questions=True,
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
max_prompt_size=chat_model.max_prompt_size,
personality_context=personality_context,
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = conversation_config.ai_model_api.api_key
base_url = conversation_config.ai_model_api.api_base_url
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
base_url = chat_model.ai_model_api.api_base_url
chat_model_name = chat_model.name
inferred_queries = extract_questions(
defiltered_query,
model=chat_model,
model=chat_model_name,
api_key=api_key,
api_base_url=base_url,
conversation_log=meta_log,
@ -447,13 +442,13 @@ async def extract_references_and_questions(
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
chat_model_name = chat_model.name
inferred_queries = extract_questions_anthropic(
defiltered_query,
query_images=query_images,
model=chat_model,
model=chat_model_name,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
@ -463,17 +458,17 @@ async def extract_references_and_questions(
query_files=query_files,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
chat_model = conversation_config.chat_model
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
chat_model_name = chat_model.name
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model,
model=chat_model_name,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
max_tokens=chat_model.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,

View file

@ -62,7 +62,7 @@ async def all_agents(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@ -150,7 +150,7 @@ async def get_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@ -225,7 +225,7 @@ async def create_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@ -286,7 +286,7 @@ async def update_agent(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"chat_model": agent.chat_model.name,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,

View file

@ -58,7 +58,7 @@ from khoj.routers.helpers import (
is_ready_to_chat,
read_chat_stream,
update_telemetry_state,
validate_conversation_config,
validate_chat_model,
)
from khoj.routers.research import (
InformationCollectionIteration,
@ -205,7 +205,7 @@ def chat_history(
n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config(user)
validate_chat_model(user)
# Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user(
@ -898,10 +898,10 @@ async def chat(
custom_filters = []
if conversation_commands == [ConversationCommand.Help]:
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
chat_model = await ConversationAdapters.aget_user_chat_model(user)
if chat_model == None:
chat_model = await ConversationAdapters.aget_default_chat_model(user)
model_type = chat_model.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help, tracer.get("usage")):
yield result

View file

@ -24,7 +24,7 @@ def get_chat_model_options(
all_conversation_options = list()
for conversation_option in conversation_options:
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
all_conversation_options.append({"chat_model": conversation_option.name, "id": conversation_option.id})
return Response(content=json.dumps(all_conversation_options), media_type="application/json", status_code=200)
@ -37,12 +37,12 @@ def get_user_chat_model(
):
user = request.user.object
chat_model = ConversationAdapters.get_conversation_config(user)
chat_model = ConversationAdapters.get_chat_model(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config(user)
chat_model = ConversationAdapters.get_default_chat_model(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.name}))
@api_model.post("/chat", status_code=200)

View file

@ -56,7 +56,7 @@ from khoj.database.adapters import (
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ChatModel,
ClientApplication,
Conversation,
GithubConfig,
@ -133,40 +133,40 @@ def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip())
def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config(user)
def validate_chat_model(user: KhojUser):
default_chat_model = ConversationAdapters.get_default_chat_model(user)
if default_config is None:
if default_chat_model is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.ai_model_api:
if default_chat_model.model_type == "openai" and not default_chat_model.ai_model_api:
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
user_chat_model = await ConversationAdapters.aget_user_chat_model(user)
if user_chat_model == None:
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
chat_model_name = user_chat_model.name
max_tokens = user_chat_model.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
return True
if (
user_conversation_config
user_chat_model
and (
user_conversation_config.model_type
user_chat_model.model_type
in [
ChatModelOptions.ModelType.OPENAI,
ChatModelOptions.ModelType.ANTHROPIC,
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.OPENAI,
ChatModel.ModelType.ANTHROPIC,
ChatModel.ModelType.GOOGLE,
]
)
and user_conversation_config.ai_model_api
and user_chat_model.ai_model_api
):
return True
@ -942,120 +942,124 @@ async def send_message_to_model_wrapper(
query_files: str = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
chat_model = vision_enabled_config
vision_available = True
if vision_available and query_images:
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model
chat_model_name = chat_model.name
max_tokens = (
conversation_config.subscribed_max_prompt_size
if subscribed and conversation_config.subscribed_max_prompt_size
else conversation_config.max_prompt_size
chat_model.subscribed_max_prompt_size
if subscribed and chat_model.subscribed_max_prompt_size
else chat_model.max_prompt_size
)
tokenizer = conversation_config.tokenizer
model_type = conversation_config.model_type
vision_available = conversation_config.vision_enabled
tokenizer = chat_model.tokenizer
model_type = chat_model.model_type
vision_available = chat_model.vision_enabled
if model_type == ChatModelOptions.ModelType.OFFLINE:
if model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
loaded_model=loaded_model,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.ai_model_api
elif model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
query_images=query_images,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
messages=truncated_messages,
api_key=api_key,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -1069,99 +1073,99 @@ def send_message_to_model_wrapper_sync(
query_files: str = "",
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
chat_model: ChatModel = ConversationAdapters.get_default_chat_model(user)
if conversation_config is None:
if chat_model is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
vision_available = conversation_config.vision_enabled
chat_model_name = chat_model.name
max_tokens = chat_model.max_prompt_size
vision_available = chat_model.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
loaded_model=loaded_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return send_message_to_model_offline(
messages=truncated_messages,
loaded_model=loaded_model,
model=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
openai_response = send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
return openai_response
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return anthropic_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
model_name=chat_model,
model_name=chat_model_name,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
model_type=chat_model.model_type,
query_files=query_files,
)
return gemini_send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
model=chat_model_name,
response_type=response_type,
tracer=tracer,
)
@ -1229,15 +1233,15 @@ def generate_chat_response(
online_results = {}
code_results = {}
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
chat_model = ConversationAdapters.get_valid_chat_model(user, conversation)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
chat_model = vision_enabled_config
vision_available = True
if conversation_config.model_type == "offline":
if chat_model.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
user_query=query_to_run,
@ -1247,9 +1251,9 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_commands=conversation_commands,
model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
model_name=chat_model.name,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@ -1259,10 +1263,10 @@ def generate_chat_response(
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.ai_model_api
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
chat_model_name = chat_model.name
chat_response = converse_openai(
compiled_references,
query_to_run,
@ -1270,13 +1274,13 @@ def generate_chat_response(
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=chat_model,
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1288,8 +1292,8 @@ def generate_chat_response(
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
api_key = chat_model.ai_model_api.api_key
chat_response = converse_anthropic(
compiled_references,
query_to_run,
@@ -1297,12 +1301,12 @@ def generate_chat_response(
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
model=chat_model.name,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1313,20 +1317,20 @@ def generate_chat_response(
program_execution_context=program_execution_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.ai_model_api.api_key
elif chat_model.model_type == ChatModel.ModelType.GOOGLE:
api_key = chat_model.ai_model_api.api_key
chat_response = converse_gemini(
compiled_references,
query_to_run,
online_results,
code_results,
meta_log,
model=conversation_config.chat_model,
model=chat_model.name,
api_key=api_key,
completion_func=partial_completion,
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
max_prompt_size=chat_model.max_prompt_size,
tokenizer_name=chat_model.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
@@ -1339,7 +1343,7 @@ def generate_chat_response(
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
metadata.update({"chat_model": chat_model.name})
except Exception as e:
logger.error(e, exc_info=True)
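One behavioural detail in `generate_chat_response` worth calling out: when the selected model lacks vision support but the query carries images, the server-wide vision-enabled model is swapped in. A toy sketch of that fallback, using a simplified in-memory model object rather than the Django model:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatModel:  # toy stand-in for the Django model
    name: str
    vision_enabled: bool = False


def resolve_model_for_query(
    selected: ChatModel,
    vision_fallback: Optional[ChatModel],
    query_images: list,
) -> tuple[ChatModel, bool]:
    # Mirrors the diff: keep the user's model unless images are attached
    # and it cannot handle them, in which case use the server's
    # vision-enabled model when one is configured.
    if query_images and not selected.vision_enabled and vision_fallback:
        return vision_fallback, True
    return selected, selected.vision_enabled


model, vision = resolve_model_for_query(
    ChatModel("llama3.1"), ChatModel("gpt-4o-mini", vision_enabled=True), ["photo.png"]
)
print(model.name, vision)  # gpt-4o-mini True
```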
@@ -1939,13 +1943,13 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ConversationAdapters.get_conversation_config(
selected_chat_model_config = ConversationAdapters.get_chat_model(
user
) or ConversationAdapters.get_default_conversation_config(user)
) or ConversationAdapters.get_default_chat_model(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
chat_model_options.append({"name": chat_model.name, "id": chat_model.id})
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
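The `get_user_config` hunk keeps the same fallback shape (prefer the user's selection, else the server default) and serialises the renamed `name` field. A short self-contained sketch, with a hypothetical in-memory model list standing in for the queryset:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatModel:  # hypothetical stand-in for the Django model
    id: int
    name: str


def selected_chat_model(user_choice: Optional[ChatModel], server_default: ChatModel) -> ChatModel:
    # Mirrors `get_chat_model(user) or get_default_chat_model(user)`:
    # a missing per-user setting falls through to the server default.
    return user_choice or server_default


def chat_model_options(models: list[ChatModel]) -> list[dict]:
    # The options payload now reads the renamed `name` field.
    return [{"name": m.name, "id": m.id} for m in models]


models = [ChatModel(1, "gpt-4o-mini"), ChatModel(2, "llama3.1")]
print(selected_chat_model(None, models[0]).name)  # gpt-4o-mini
print(chat_model_options(models))
```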

View file

@@ -584,13 +584,15 @@ def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_toke
return input_cost + output_cost + prev_cost
def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
def get_chat_usage_metrics(
model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}, cost: float = None
):
"""
Get usage metrics for chat message based on input and output tokens
Get usage metrics for chat message based on input and output tokens and cost
"""
prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
return {
"input_tokens": prev_usage["input_tokens"] + input_tokens,
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
}
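The new optional `cost` argument lets callers pass through a provider-reported cost instead of recomputing it from token counts. A self-contained re-implementation to show the accumulation and the override; the per-token rates here are illustrative only:

```python
def get_cost_of_chat_message(
    model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0
) -> float:
    # Illustrative per-token rates; the real function looks up prices per model.
    rates = {"gpt-4o-mini": (0.15e-6, 0.60e-6)}
    input_rate, output_rate = rates.get(model_name, (0.0, 0.0))
    return prev_cost + input_tokens * input_rate + output_tokens * output_rate


def get_chat_usage_metrics(
    model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = None, cost: float = None
):
    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
    return {
        "input_tokens": prev_usage["input_tokens"] + input_tokens,
        "output_tokens": prev_usage["output_tokens"] + output_tokens,
        # A truthy explicit cost wins; otherwise accumulate the computed cost.
        "cost": cost
        or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
    }


usage = get_chat_usage_metrics("gpt-4o-mini", input_tokens=1000, output_tokens=200)
usage = get_chat_usage_metrics("gpt-4o-mini", 500, 100, usage=usage, cost=0.0042)
print(usage)  # the explicit cost replaced the computed running total
```

Note that `cost or ...` treats an explicit cost of 0.0 as unset, so such a caller still falls through to the computed path.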

View file

@@ -7,7 +7,7 @@ import openai
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import (
AiModelApi,
ChatModelOptions,
ChatModel,
KhojUser,
SpeechToTextModelOptions,
TextToImageModelConfig,
@@ -63,7 +63,7 @@ def initialization(interactive: bool = True):
# Set up OpenAI's online chat models
openai_configured, openai_provider = _setup_chat_model_provider(
ChatModelOptions.ModelType.OPENAI,
ChatModel.ModelType.OPENAI,
default_chat_models,
default_api_key=openai_api_key,
api_base_url=openai_api_base,
@@ -105,7 +105,7 @@ def initialization(interactive: bool = True):
# Set up Google's Gemini online chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.GOOGLE,
ChatModel.ModelType.GOOGLE,
default_gemini_chat_models,
default_api_key=os.getenv("GEMINI_API_KEY"),
vision_enabled=True,
@@ -116,7 +116,7 @@ def initialization(interactive: bool = True):
# Set up Anthropic's online chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.ANTHROPIC,
ChatModel.ModelType.ANTHROPIC,
default_anthropic_chat_models,
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
vision_enabled=True,
@@ -126,7 +126,7 @@ def initialization(interactive: bool = True):
# Set up offline chat models
_setup_chat_model_provider(
ChatModelOptions.ModelType.OFFLINE,
ChatModel.ModelType.OFFLINE,
default_offline_chat_models,
default_api_key=None,
vision_enabled=False,
@@ -135,9 +135,9 @@ def initialization(interactive: bool = True):
)
# Explicitly set default chat model
chat_models_configured = ChatModelOptions.objects.count()
chat_models_configured = ChatModel.objects.count()
if chat_models_configured > 0:
default_chat_model_name = ChatModelOptions.objects.first().chat_model
default_chat_model_name = ChatModel.objects.first().name
# If there are multiple chat models, ask the user to choose the default chat model
if chat_models_configured > 1 and interactive:
user_chat_model_name = input(
@@ -147,7 +147,7 @@ def initialization(interactive: bool = True):
user_chat_model_name = None
# If the user's choice is valid, set it as the default chat model
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
if user_chat_model_name and ChatModel.objects.filter(name=user_chat_model_name).exists():
default_chat_model_name = user_chat_model_name
logger.info("🗣️ Chat model configuration complete")
@@ -171,7 +171,7 @@ def initialization(interactive: bool = True):
logger.info(f"🗣️ Offline speech to text model configured to {offline_speech2text_model}")
def _setup_chat_model_provider(
model_type: ChatModelOptions.ModelType,
model_type: ChatModel.ModelType,
default_chat_models: list,
default_api_key: str,
interactive: bool,
@@ -204,10 +204,10 @@ def initialization(interactive: bool = True):
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
if interactive:
chat_model_names = input(
user_chat_models = input(
f"Enter the {provider_name} chat models you want to use (default: {','.join(default_chat_models)}): "
)
chat_models = chat_model_names.split(",") if chat_model_names != "" else default_chat_models
chat_models = user_chat_models.split(",") if user_chat_models != "" else default_chat_models
chat_models = [model.strip() for model in chat_models]
else:
chat_models = default_chat_models
@@ -218,7 +218,7 @@ def initialization(interactive: bool = True):
vision_enabled = vision_enabled and chat_model in supported_vision_models
chat_model_options = {
"chat_model": chat_model,
"name": chat_model,
"model_type": model_type,
"max_prompt_size": default_max_tokens,
"vision_enabled": vision_enabled,
@ -226,7 +226,7 @@ def initialization(interactive: bool = True):
"ai_model_api": ai_model_api,
}
ChatModelOptions.objects.create(**chat_model_options)
ChatModel.objects.create(**chat_model_options)
logger.info(f"🗣️ {provider_name} chat model configuration complete")
return True, ai_model_api
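Besides the `"chat_model"` to `"name"` kwarg rename, this hunk renames the interactive input variable. The prompt parsing is simple enough to lift into a self-contained helper, sketched here for clarity:

```python
def parse_model_names(raw_input: str, default_chat_models: list[str]) -> list[str]:
    # Mirrors the interactive prompt handling above: an empty reply keeps
    # the defaults, otherwise split on commas and strip whitespace.
    chat_models = raw_input.split(",") if raw_input != "" else default_chat_models
    return [model.strip() for model in chat_models]


assert parse_model_names("", ["gpt-4o-mini"]) == ["gpt-4o-mini"]
assert parse_model_names("llama3.1, gemma2 ", []) == ["llama3.1", "gemma2"]
```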
@@ -250,19 +250,19 @@ def initialization(interactive: bool = True):
available_models = [model.id for model in openai_client.models.list()]
# Get existing chat model options for this config
existing_models = ChatModelOptions.objects.filter(
ai_model_api=config, model_type=ChatModelOptions.ModelType.OPENAI
existing_models = ChatModel.objects.filter(
ai_model_api=config, model_type=ChatModel.ModelType.OPENAI
)
# Add new models
for model in available_models:
if not existing_models.filter(chat_model=model).exists():
ChatModelOptions.objects.create(
chat_model=model,
model_type=ChatModelOptions.ModelType.OPENAI,
max_prompt_size=model_to_prompt_size.get(model),
vision_enabled=model in default_openai_chat_models,
tokenizer=model_to_tokenizer.get(model),
for model_name in available_models:
if not existing_models.filter(name=model_name).exists():
ChatModel.objects.create(
name=model_name,
model_type=ChatModel.ModelType.OPENAI,
max_prompt_size=model_to_prompt_size.get(model_name),
vision_enabled=model_name in default_openai_chat_models,
tokenizer=model_to_tokenizer.get(model_name),
ai_model_api=config,
)
@@ -284,7 +284,7 @@ def initialization(interactive: bool = True):
except Exception as e:
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
chat_config = ConversationAdapters.get_default_conversation_config()
chat_config = ConversationAdapters.get_default_chat_model()
if admin_user is None and chat_config is None:
while True:
try:

View file

@@ -13,7 +13,7 @@ from khoj.configure import (
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ChatModel,
GithubConfig,
GithubRepoConfig,
KhojApiUser,
@@ -35,7 +35,7 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
from tests.helpers import (
AiModelApiFactory,
ChatModelOptionsFactory,
ChatModelFactory,
ProcessLockFactory,
SubscriptionFactory,
UserConversationProcessorConfigFactory,
@@ -184,14 +184,14 @@ def api_user4(default_user4):
@pytest.mark.django_db
@pytest.fixture
def default_openai_chat_model_option():
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
return chat_model
@pytest.mark.django_db
@pytest.fixture
def offline_agent():
chat_model = ChatModelOptionsFactory()
chat_model = ChatModelFactory()
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
@@ -202,7 +202,7 @@ def offline_agent():
@pytest.mark.django_db
@pytest.fixture
def openai_agent():
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
@@ -311,13 +311,13 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Initialize Processor from Config
chat_provider = get_chat_provider()
online_chat_model: ChatModelOptionsFactory = None
if chat_provider == ChatModelOptions.ModelType.OPENAI:
online_chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModelOptions.ModelType.GOOGLE:
online_chat_model = ChatModelOptionsFactory(chat_model="gemini-1.5-flash", model_type="google")
elif chat_provider == ChatModelOptions.ModelType.ANTHROPIC:
online_chat_model = ChatModelOptionsFactory(chat_model="claude-3-5-haiku-20241022", model_type="anthropic")
online_chat_model: ChatModelFactory = None
if chat_provider == ChatModel.ModelType.OPENAI:
online_chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
elif chat_provider == ChatModel.ModelType.GOOGLE:
online_chat_model = ChatModelFactory(name="gemini-1.5-flash", model_type="google")
elif chat_provider == ChatModel.ModelType.ANTHROPIC:
online_chat_model = ChatModelFactory(name="claude-3-5-haiku-20241022", model_type="anthropic")
if online_chat_model:
online_chat_model.ai_model_api = AiModelApiFactory(api_key=get_chat_api_key(chat_provider))
UserConversationProcessorConfigFactory(user=user, setting=online_chat_model)
@@ -394,8 +394,8 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
configure_content(default_user2, all_files)
# Initialize Processor from Config
ChatModelOptionsFactory(
chat_model="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
ChatModelFactory(
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
tokenizer=None,
max_prompt_size=None,
model_type="offline",
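The fixture rewrite is mechanical, but it encodes a provider-to-default-test-model mapping that is easier to read in one place. A hedged summary, with a hypothetical helper name (`default_test_model` does not exist in the codebase):

```python
# Default test model per provider, as wired up in chat_client_builder above.
PROVIDER_TEST_MODELS = {
    "openai": "gpt-4o-mini",
    "google": "gemini-1.5-flash",
    "anthropic": "claude-3-5-haiku-20241022",
}


def default_test_model(provider: str) -> str | None:
    # Hypothetical helper: returns the model name ChatModelFactory would be
    # given for the selected KHOJ_TEST_CHAT_PROVIDER, or None for offline.
    return PROVIDER_TEST_MODELS.get(provider)


assert default_test_model("google") == "gemini-1.5-flash"
assert default_test_model("offline") is None
```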

View file

@@ -6,7 +6,7 @@ from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatModelOptions,
ChatModel,
Conversation,
KhojApiUser,
KhojUser,
@@ -18,27 +18,27 @@ from khoj.database.models import (
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE):
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
if provider and provider in ChatModelOptions.ModelType:
return ChatModelOptions.ModelType(provider)
if provider and provider in ChatModel.ModelType:
return ChatModel.ModelType(provider)
elif os.getenv("OPENAI_API_KEY"):
return ChatModelOptions.ModelType.OPENAI
return ChatModel.ModelType.OPENAI
elif os.getenv("GEMINI_API_KEY"):
return ChatModelOptions.ModelType.GOOGLE
return ChatModel.ModelType.GOOGLE
elif os.getenv("ANTHROPIC_API_KEY"):
return ChatModelOptions.ModelType.ANTHROPIC
return ChatModel.ModelType.ANTHROPIC
else:
return default
def get_chat_api_key(provider: ChatModelOptions.ModelType = None):
def get_chat_api_key(provider: ChatModel.ModelType = None):
provider = provider or get_chat_provider()
if provider == ChatModelOptions.ModelType.OPENAI:
if provider == ChatModel.ModelType.OPENAI:
return os.getenv("OPENAI_API_KEY")
elif provider == ChatModelOptions.ModelType.GOOGLE:
elif provider == ChatModel.ModelType.GOOGLE:
return os.getenv("GEMINI_API_KEY")
elif provider == ChatModelOptions.ModelType.ANTHROPIC:
elif provider == ChatModel.ModelType.ANTHROPIC:
return os.getenv("ANTHROPIC_API_KEY")
else:
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
@@ -83,13 +83,13 @@ class AiModelApiFactory(factory.django.DjangoModelFactory):
api_key = get_chat_api_key()
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class ChatModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions
model = ChatModel
max_prompt_size = 20000
tokenizer = None
chat_model = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
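For test authors, the practical consequence of the factory rename is the keyword change at call sites. A usage sketch, assuming pytest-django, a migrated Khoj test database, and the existing `default_user` fixture; the test name is hypothetical:

```python
import pytest

from tests.helpers import ChatModelFactory, UserConversationProcessorConfigFactory


@pytest.mark.django_db
def test_chat_model_factory_rename(default_user):
    # `name=` replaces the old `chat_model=` keyword on the factory.
    chat_model = ChatModelFactory(name="gpt-4o-mini", model_type="openai")
    assert chat_model.name == "gpt-4o-mini"
    assert chat_model.max_prompt_size == 20000  # factory default above

    # The per-user setting factory is unchanged apart from its SubFactory.
    config = UserConversationProcessorConfigFactory(user=default_user, setting=chat_model)
    assert config.setting.name == "gpt-4o-mini"
```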
@@ -99,7 +99,7 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
model = UserConversationConfig
user = factory.SubFactory(UserFactory)
setting = factory.SubFactory(ChatModelOptionsFactory)
setting = factory.SubFactory(ChatModelFactory)
class ConversationFactory(factory.django.DjangoModelFactory):

View file

@@ -5,14 +5,14 @@ import pytest
from asgiref.sync import sync_to_async
from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.routers.api import execute_search
from khoj.utils.helpers import get_absolute_path
from tests.helpers import ChatModelOptionsFactory
from tests.helpers import ChatModelFactory
def test_create_default_agent(default_user: KhojUser):
ChatModelOptionsFactory()
ChatModelFactory()
agent = AgentAdapters.create_default_agent(default_user)
assert agent is not None
@@ -24,7 +24,7 @@ def test_create_default_agent(default_user: KhojUser):
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModel):
new_agent = await AgentAdapters.aupdate_agent(
default_user,
"Test Agent",
@@ -32,7 +32,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[],
[],
[],
@@ -46,7 +46,7 @@ async def test_create_or_update_agent(default_user: KhojUser, default_openai_cha
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -56,7 +56,7 @@ async def test_create_or_update_agent_with_knowledge_base(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -78,7 +78,7 @@ async def test_create_or_update_agent_with_knowledge_base(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base_and_search(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -88,7 +88,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -102,7 +102,7 @@ async def test_create_or_update_agent_with_knowledge_base_and_search(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -112,7 +112,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -126,7 +126,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -136,7 +136,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -150,7 +150,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -160,7 +160,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
Agent.PrivacyLevel.PRIVATE,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -174,7 +174,7 @@ async def test_agent_with_knowledge_base_and_search_not_creator_and_private_acce
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_multiple_agents_with_knowledge_base_and_users(
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
default_user2: KhojUser, default_openai_chat_model_option: ChatModel, chat_client, default_user3: KhojUser
):
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
new_agent = await AgentAdapters.aupdate_agent(
@@ -184,7 +184,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename],
[],
[],
@@ -198,7 +198,7 @@ async def test_multiple_agents_with_knowledge_base_and_users(
Agent.PrivacyLevel.PUBLIC,
"icon",
"color",
default_openai_chat_model_option.chat_model,
default_openai_chat_model_option.name,
[full_filename2],
[],
[],

View file

@@ -2,12 +2,12 @@ from datetime import datetime
import pytest
from khoj.database.models import ChatModelOptions
from khoj.database.models import ChatModel
from khoj.routers.helpers import aget_data_sources_and_output_format
from khoj.utils.helpers import ConversationCommand
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",

View file

@@ -4,12 +4,12 @@ import pytest
from faker import Faker
from freezegun import freeze_time
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from tests.helpers import ConversationFactory, get_chat_provider
SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="Disable in CI to avoid long test runs.",
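Both offline-only test modules gate on the renamed enum with the same two lines. A self-contained sketch of that pattern; the provider check is simplified to the environment variable, whereas the real `get_chat_provider` also falls back on which API keys are set:

```python
import os

import pytest

# Simplified stand-in for get_chat_provider(default=None): the real helper
# also infers the provider from OPENAI/GEMINI/ANTHROPIC API keys.
SKIP_TESTS = os.getenv("KHOJ_TEST_CHAT_PROVIDER") != "offline"

pytestmark = pytest.mark.skipif(
    SKIP_TESTS,
    reason="Disable in CI to avoid long test runs.",
)


def test_placeholder_offline_behaviour():
    # Placeholder body; the real modules exercise the offline chat model.
    assert True
```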