Get rid of enable flag for the offline chat processor config

- By default, assume that offline chat is enabled if there is an offline chat model option configured
This commit is contained in:
sabaimran 2024-04-23 23:08:29 +05:30
parent ac474fce38
commit 60658a8037
13 changed files with 33 additions and 67 deletions

View file

@@ -28,7 +28,6 @@ from khoj.database.models import (
KhojApiUser, KhojApiUser,
KhojUser, KhojUser,
NotionConfig, NotionConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ProcessLock, ProcessLock,
ReflectiveQuestion, ReflectiveQuestion,
@@ -628,18 +627,6 @@ class ConversationAdapters:
async def aget_openai_conversation_config(): async def aget_openai_conversation_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst() return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod
def get_offline_chat_conversation_config():
return OfflineChatProcessorConversationConfig.objects.filter().first()
@staticmethod
async def aget_offline_chat_conversation_config():
return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
@staticmethod
def has_valid_offline_conversation_config():
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
@staticmethod @staticmethod
def has_valid_openai_conversation_config(): def has_valid_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().exists() return OpenAIProcessorConversationConfig.objects.filter().exists()
@@ -710,14 +697,6 @@ class ConversationAdapters:
user_conversation_config.setting = new_config user_conversation_config.setting = new_config
user_conversation_config.save() user_conversation_config.save()
@staticmethod
def has_offline_chat():
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
@staticmethod
async def ahas_offline_chat():
return await OfflineChatProcessorConversationConfig.objects.filter(enabled=True).aexists()
@staticmethod @staticmethod
async def get_default_offline_llm(): async def get_default_offline_llm():
return await ChatModelOptions.objects.filter(model_type="offline").afirst() return await ChatModelOptions.objects.filter(model_type="offline").afirst()
@@ -765,8 +744,6 @@ class ConversationAdapters:
@staticmethod @staticmethod
def get_valid_conversation_config(user: KhojUser, conversation: Conversation): def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
if conversation.agent and conversation.agent.chat_model: if conversation.agent and conversation.agent.chat_model:
conversation_config = conversation.agent.chat_model conversation_config = conversation.agent.chat_model
else: else:
@@ -775,7 +752,7 @@ class ConversationAdapters:
if conversation_config is None: if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config() conversation_config = ConversationAdapters.get_default_conversation_config()
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline": if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size max_tokens = conversation_config.max_prompt_size

View file

@@ -14,7 +14,6 @@ from khoj.database.models import (
GithubConfig, GithubConfig,
KhojUser, KhojUser,
NotionConfig, NotionConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ReflectiveQuestion, ReflectiveQuestion,
SearchModelConfig, SearchModelConfig,
@@ -47,7 +46,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(SpeechToTextModelOptions) admin.site.register(SpeechToTextModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModelConfig) admin.site.register(SearchModelConfig)
admin.site.register(Subscription) admin.site.register(Subscription)
admin.site.register(ReflectiveQuestion) admin.site.register(ReflectiveQuestion)

View file

@@ -0,0 +1,15 @@
# Generated by Django 4.2.10 on 2024-04-23 17:35
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0035_processlock"),
]
operations = [
migrations.DeleteModel(
name="OfflineChatProcessorConversationConfig",
),
]

View file

@@ -201,10 +201,6 @@ class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200) api_key = models.CharField(max_length=200)
class OfflineChatProcessorConversationConfig(BaseModel):
enabled = models.BooleanField(default=False)
class SpeechToTextModelOptions(BaseModel): class SpeechToTextModelOptions(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"

View file

@@ -62,7 +62,6 @@ from packaging import version
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SearchModelConfig, SearchModelConfig,
) )
@@ -103,9 +102,6 @@ def migrate_server_pg(args):
if "offline-chat" in raw_config["processor"]["conversation"]: if "offline-chat" in raw_config["processor"]["conversation"]:
offline_chat = raw_config["processor"]["conversation"]["offline-chat"] offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
OfflineChatProcessorConversationConfig.objects.create(
enabled=offline_chat.get("enable-offline-chat"),
)
ChatModelOptions.objects.create( ChatModelOptions.objects.create(
chat_model=offline_chat.get("chat-model"), chat_model=offline_chat.get("chat-model"),
tokenizer=processor_conversation.get("tokenizer"), tokenizer=processor_conversation.get("tokenizer"),

View file

@@ -67,5 +67,6 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int: def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
"""Infer max prompt size based on device memory and max context window supported by the model""" """Infer max prompt size based on device memory and max context window supported by the model"""
configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
vram_based_n_ctx = int(get_device_memory() / 2e6) # based on heuristic vram_based_n_ctx = int(get_device_memory() / 2e6) # based on heuristic
return min(configured_max_tokens, vram_based_n_ctx, model_context_window) return min(configured_max_tokens, vram_based_n_ctx, model_context_window)

View file

@@ -303,15 +303,10 @@ async def extract_references_and_questions(
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
conversation_config = await ConversationAdapters.aget_conversation_config(user) conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None: if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
if ( if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
offline_chat_config
and offline_chat_config.enabled
and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
):
using_offline_chat = True using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm() default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model chat_model = default_offline_llm.chat_model

View file

@@ -65,23 +65,20 @@ executor = ThreadPoolExecutor(max_workers=1)
def validate_conversation_config(): def validate_conversation_config():
if ( default_config = ConversationAdapters.get_default_conversation_config()
ConversationAdapters.has_valid_offline_conversation_config()
or ConversationAdapters.has_valid_openai_conversation_config()
):
if ConversationAdapters.get_default_conversation_config() is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
return
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
async def is_ready_to_chat(user: KhojUser): async def is_ready_to_chat(user: KhojUser):
has_offline_config = await ConversationAdapters.ahas_offline_chat()
has_openai_config = await ConversationAdapters.has_openai_chat() has_openai_config = await ConversationAdapters.has_openai_chat()
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user) user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline": if user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None: if state.offline_chat_processor_config is None:
@@ -89,9 +86,7 @@ async def is_ready_to_chat(user: KhojUser):
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True return True
ready = has_openai_config or has_offline_config if not has_openai_config:
if not ready:
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")

View file

@@ -5,7 +5,6 @@ from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ( from khoj.database.models import (
ChatModelOptions, ChatModelOptions,
KhojUser, KhojUser,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SpeechToTextModelOptions, SpeechToTextModelOptions,
TextToImageModelConfig, TextToImageModelConfig,
@@ -35,7 +34,6 @@ def initialization():
use_offline_model = input("Use offline chat model? (y/n): ") use_offline_model = input("Use offline chat model? (y/n): ")
if use_offline_model == "y": if use_offline_model == "y":
logger.info("🗣️ Setting up offline chat model") logger.info("🗣️ Setting up offline chat model")
OfflineChatProcessorConversationConfig.objects.create(enabled=True)
offline_chat_model = input( offline_chat_model = input(
f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): " f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): "

View file

@@ -81,7 +81,6 @@ class OpenAIProcessorConfig(ConfigBase):
class OfflineChatProcessorConfig(ConfigBase): class OfflineChatProcessorConfig(ConfigBase):
enable_offline_chat: Optional[bool] = False
chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"

View file

@@ -33,7 +33,6 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
from tests.helpers import ( from tests.helpers import (
ChatModelOptionsFactory, ChatModelOptionsFactory,
OfflineChatProcessorConversationConfigFactory,
OpenAIProcessorConversationConfigFactory, OpenAIProcessorConversationConfigFactory,
ProcessLockFactory, ProcessLockFactory,
SubscriptionFactory, SubscriptionFactory,
@@ -377,7 +376,12 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
configure_content(all_files, user=default_user2) configure_content(all_files, user=default_user2)
# Initialize Processor from Config # Initialize Processor from Config
OfflineChatProcessorConversationConfigFactory(enabled=True) ChatModelOptionsFactory(
chat_model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
tokenizer=None,
max_prompt_size=None,
model_type="offline",
)
UserConversationProcessorConfigFactory(user=default_user2) UserConversationProcessorConfigFactory(user=default_user2)
state.anonymous_mode = True state.anonymous_mode = True

View file

@@ -9,7 +9,6 @@ from khoj.database.models import (
Conversation, Conversation,
KhojApiUser, KhojApiUser,
KhojUser, KhojUser,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
ProcessLock, ProcessLock,
SearchModelConfig, SearchModelConfig,
@@ -55,13 +54,6 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
setting = factory.SubFactory(ChatModelOptionsFactory) setting = factory.SubFactory(ChatModelOptionsFactory)
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OfflineChatProcessorConversationConfig
enabled = True
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory): class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta: class Meta:
model = OpenAIProcessorConversationConfig model = OpenAIProcessorConversationConfig

View file

@@ -24,7 +24,7 @@ from khoj.utils.constants import default_offline_chat_model
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def loaded_model(): def loaded_model():
return download_model(default_offline_chat_model) return download_model(default_offline_chat_model, max_tokens=5000)
freezegun.configure(extend_ignore_list=["transformers"]) freezegun.configure(extend_ignore_list=["transformers"])