khoj/tests/helpers.py
Debanjum 01bc6d35dc
Rename Chat Model Options table to Chat Model as short & readable (#1003)
- Previous was incorrectly plural but was defining only a single model
- Rename chat model table field to name
- Update documentation
- Update references functions and variables to match new name
2024-12-12 11:24:16 -08:00

136 lines
3.8 KiB
Python

import os
from datetime import datetime
import factory
from django.utils.timezone import make_aware
from khoj.database.models import (
AiModelApi,
ChatModel,
Conversation,
KhojApiUser,
KhojUser,
ProcessLock,
SearchModelConfig,
Subscription,
UserConversationConfig,
)
from khoj.processor.conversation.utils import message_to_log
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
if provider and provider in ChatModel.ModelType:
return ChatModel.ModelType(provider)
elif os.getenv("OPENAI_API_KEY"):
return ChatModel.ModelType.OPENAI
elif os.getenv("GEMINI_API_KEY"):
return ChatModel.ModelType.GOOGLE
elif os.getenv("ANTHROPIC_API_KEY"):
return ChatModel.ModelType.ANTHROPIC
else:
return default
def get_chat_api_key(provider: ChatModel.ModelType = None):
provider = provider or get_chat_provider()
if provider == ChatModel.ModelType.OPENAI:
return os.getenv("OPENAI_API_KEY")
elif provider == ChatModel.ModelType.GOOGLE:
return os.getenv("GEMINI_API_KEY")
elif provider == ChatModel.ModelType.ANTHROPIC:
return os.getenv("ANTHROPIC_API_KEY")
else:
return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
def generate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, chat_response, context in message_list:
message_to_log(
user_message,
chat_response,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log.get("chat", []),
)
return conversation_log
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
class ApiUserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojApiUser
user = None
name = factory.Faker("name")
token = factory.Faker("password")
class AiModelApiFactory(factory.django.DjangoModelFactory):
class Meta:
model = AiModelApi
api_key = get_chat_api_key()
class ChatModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModel
max_prompt_size = 20000
tokenizer = None
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
model_type = get_chat_provider()
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = UserConversationConfig
user = factory.SubFactory(UserFactory)
setting = factory.SubFactory(ChatModelFactory)
class ConversationFactory(factory.django.DjangoModelFactory):
class Meta:
model = Conversation
user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = SearchModelConfig
name = "default"
model_type = "text"
bi_encoder = "thenlper/gte-small"
cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription
user = factory.SubFactory(UserFactory)
type = Subscription.Type.STANDARD
is_recurring = False
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
class ProcessLockFactory(factory.django.DjangoModelFactory):
class Meta:
model = ProcessLock
name = "test_lock"