mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Get rid of enable flag for the offline chat processor config
- By default, assume offline chat is enabled if an offline chat model option is configured
This commit is contained in:
parent
ac474fce38
commit
60658a8037
13 changed files with 33 additions and 67 deletions
|
@ -28,7 +28,6 @@ from khoj.database.models import (
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
|
@ -628,18 +627,6 @@ class ConversationAdapters:
|
||||||
async def aget_openai_conversation_config():
|
async def aget_openai_conversation_config():
|
||||||
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_offline_chat_conversation_config():
|
|
||||||
return OfflineChatProcessorConversationConfig.objects.filter().first()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def aget_offline_chat_conversation_config():
|
|
||||||
return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def has_valid_offline_conversation_config():
|
|
||||||
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def has_valid_openai_conversation_config():
|
def has_valid_openai_conversation_config():
|
||||||
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
return OpenAIProcessorConversationConfig.objects.filter().exists()
|
||||||
|
@ -710,14 +697,6 @@ class ConversationAdapters:
|
||||||
user_conversation_config.setting = new_config
|
user_conversation_config.setting = new_config
|
||||||
user_conversation_config.save()
|
user_conversation_config.save()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def has_offline_chat():
|
|
||||||
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def ahas_offline_chat():
|
|
||||||
return await OfflineChatProcessorConversationConfig.objects.filter(enabled=True).aexists()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_default_offline_llm():
|
async def get_default_offline_llm():
|
||||||
return await ChatModelOptions.objects.filter(model_type="offline").afirst()
|
return await ChatModelOptions.objects.filter(model_type="offline").afirst()
|
||||||
|
@ -765,8 +744,6 @@ class ConversationAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
||||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
|
|
||||||
|
|
||||||
if conversation.agent and conversation.agent.chat_model:
|
if conversation.agent and conversation.agent.chat_model:
|
||||||
conversation_config = conversation.agent.chat_model
|
conversation_config = conversation.agent.chat_model
|
||||||
else:
|
else:
|
||||||
|
@ -775,7 +752,7 @@ class ConversationAdapters:
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||||
|
|
||||||
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
max_tokens = conversation_config.max_prompt_size
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
|
|
@ -14,7 +14,6 @@ from khoj.database.models import (
|
||||||
GithubConfig,
|
GithubConfig,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
NotionConfig,
|
NotionConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
ReflectiveQuestion,
|
ReflectiveQuestion,
|
||||||
SearchModelConfig,
|
SearchModelConfig,
|
||||||
|
@ -47,7 +46,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
|
||||||
admin.site.register(ChatModelOptions)
|
admin.site.register(ChatModelOptions)
|
||||||
admin.site.register(SpeechToTextModelOptions)
|
admin.site.register(SpeechToTextModelOptions)
|
||||||
admin.site.register(OpenAIProcessorConversationConfig)
|
admin.site.register(OpenAIProcessorConversationConfig)
|
||||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
|
||||||
admin.site.register(SearchModelConfig)
|
admin.site.register(SearchModelConfig)
|
||||||
admin.site.register(Subscription)
|
admin.site.register(Subscription)
|
||||||
admin.site.register(ReflectiveQuestion)
|
admin.site.register(ReflectiveQuestion)
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Generated by Django 4.2.10 on 2024-04-23 17:35
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0035_processlock"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.DeleteModel(
|
||||||
|
name="OfflineChatProcessorConversationConfig",
|
||||||
|
),
|
||||||
|
]
|
|
@ -201,10 +201,6 @@ class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
class OfflineChatProcessorConversationConfig(BaseModel):
|
|
||||||
enabled = models.BooleanField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextModelOptions(BaseModel):
|
class SpeechToTextModelOptions(BaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
|
@ -62,7 +62,6 @@ from packaging import version
|
||||||
|
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OfflineChatProcessorConversationConfig,
|
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
SearchModelConfig,
|
SearchModelConfig,
|
||||||
)
|
)
|
||||||
|
@ -103,9 +102,6 @@ def migrate_server_pg(args):
|
||||||
|
|
||||||
if "offline-chat" in raw_config["processor"]["conversation"]:
|
if "offline-chat" in raw_config["processor"]["conversation"]:
|
||||||
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
|
offline_chat = raw_config["processor"]["conversation"]["offline-chat"]
|
||||||
OfflineChatProcessorConversationConfig.objects.create(
|
|
||||||
enabled=offline_chat.get("enable-offline-chat"),
|
|
||||||
)
|
|
||||||
ChatModelOptions.objects.create(
|
ChatModelOptions.objects.create(
|
||||||
chat_model=offline_chat.get("chat-model"),
|
chat_model=offline_chat.get("chat-model"),
|
||||||
tokenizer=processor_conversation.get("tokenizer"),
|
tokenizer=processor_conversation.get("tokenizer"),
|
||||||
|
|
|
@ -67,5 +67,6 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||||
|
|
||||||
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
|
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
|
||||||
"""Infer max prompt size based on device memory and max context window supported by the model"""
|
"""Infer max prompt size based on device memory and max context window supported by the model"""
|
||||||
|
configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
|
||||||
vram_based_n_ctx = int(get_device_memory() / 2e6) # based on heuristic
|
vram_based_n_ctx = int(get_device_memory() / 2e6) # based on heuristic
|
||||||
return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
|
return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
|
||||||
|
|
|
@ -303,15 +303,10 @@ async def extract_references_and_questions(
|
||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
|
|
||||||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
if (
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
offline_chat_config
|
|
||||||
and offline_chat_config.enabled
|
|
||||||
and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
|
|
||||||
):
|
|
||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
||||||
chat_model = default_offline_llm.chat_model
|
chat_model = default_offline_llm.chat_model
|
||||||
|
|
|
@ -65,23 +65,20 @@ executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
|
|
||||||
def validate_conversation_config():
|
def validate_conversation_config():
|
||||||
if (
|
default_config = ConversationAdapters.get_default_conversation_config()
|
||||||
ConversationAdapters.has_valid_offline_conversation_config()
|
|
||||||
or ConversationAdapters.has_valid_openai_conversation_config()
|
|
||||||
):
|
|
||||||
if ConversationAdapters.get_default_conversation_config() is None:
|
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
|
||||||
return
|
|
||||||
|
|
||||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
if default_config is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
|
if default_config.model_type == "openai" and not ConversationAdapters.has_valid_openai_conversation_config():
|
||||||
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
|
|
||||||
async def is_ready_to_chat(user: KhojUser):
|
async def is_ready_to_chat(user: KhojUser):
|
||||||
has_offline_config = await ConversationAdapters.ahas_offline_chat()
|
|
||||||
has_openai_config = await ConversationAdapters.has_openai_chat()
|
has_openai_config = await ConversationAdapters.has_openai_chat()
|
||||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||||
|
|
||||||
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
if user_conversation_config and user_conversation_config.model_type == "offline":
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
max_tokens = user_conversation_config.max_prompt_size
|
max_tokens = user_conversation_config.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
|
@ -89,9 +86,7 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
ready = has_openai_config or has_offline_config
|
if not has_openai_config:
|
||||||
|
|
||||||
if not ready:
|
|
||||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
OfflineChatProcessorConversationConfig,
|
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
SpeechToTextModelOptions,
|
SpeechToTextModelOptions,
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
|
@ -35,7 +34,6 @@ def initialization():
|
||||||
use_offline_model = input("Use offline chat model? (y/n): ")
|
use_offline_model = input("Use offline chat model? (y/n): ")
|
||||||
if use_offline_model == "y":
|
if use_offline_model == "y":
|
||||||
logger.info("🗣️ Setting up offline chat model")
|
logger.info("🗣️ Setting up offline chat model")
|
||||||
OfflineChatProcessorConversationConfig.objects.create(enabled=True)
|
|
||||||
|
|
||||||
offline_chat_model = input(
|
offline_chat_model = input(
|
||||||
f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): "
|
f"Enter the offline chat model you want to use. See HuggingFace for available GGUF models (default: {default_offline_chat_model}): "
|
||||||
|
|
|
@ -81,7 +81,6 @@ class OpenAIProcessorConfig(ConfigBase):
|
||||||
|
|
||||||
|
|
||||||
class OfflineChatProcessorConfig(ConfigBase):
|
class OfflineChatProcessorConfig(ConfigBase):
|
||||||
enable_offline_chat: Optional[bool] = False
|
|
||||||
chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
|
chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,6 @@ from khoj.utils.helpers import resolve_absolute_path
|
||||||
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, ImageSearchConfig, SearchConfig
|
||||||
from tests.helpers import (
|
from tests.helpers import (
|
||||||
ChatModelOptionsFactory,
|
ChatModelOptionsFactory,
|
||||||
OfflineChatProcessorConversationConfigFactory,
|
|
||||||
OpenAIProcessorConversationConfigFactory,
|
OpenAIProcessorConversationConfigFactory,
|
||||||
ProcessLockFactory,
|
ProcessLockFactory,
|
||||||
SubscriptionFactory,
|
SubscriptionFactory,
|
||||||
|
@ -377,7 +376,12 @@ def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
configure_content(all_files, user=default_user2)
|
configure_content(all_files, user=default_user2)
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
OfflineChatProcessorConversationConfigFactory(enabled=True)
|
ChatModelOptionsFactory(
|
||||||
|
chat_model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
||||||
|
tokenizer=None,
|
||||||
|
max_prompt_size=None,
|
||||||
|
model_type="offline",
|
||||||
|
)
|
||||||
UserConversationProcessorConfigFactory(user=default_user2)
|
UserConversationProcessorConfigFactory(user=default_user2)
|
||||||
|
|
||||||
state.anonymous_mode = True
|
state.anonymous_mode = True
|
||||||
|
|
|
@ -9,7 +9,6 @@ from khoj.database.models import (
|
||||||
Conversation,
|
Conversation,
|
||||||
KhojApiUser,
|
KhojApiUser,
|
||||||
KhojUser,
|
KhojUser,
|
||||||
OfflineChatProcessorConversationConfig,
|
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
ProcessLock,
|
ProcessLock,
|
||||||
SearchModelConfig,
|
SearchModelConfig,
|
||||||
|
@ -55,13 +54,6 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||||
setting = factory.SubFactory(ChatModelOptionsFactory)
|
setting = factory.SubFactory(ChatModelOptionsFactory)
|
||||||
|
|
||||||
|
|
||||||
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
|
||||||
class Meta:
|
|
||||||
model = OfflineChatProcessorConversationConfig
|
|
||||||
|
|
||||||
enabled = True
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = OpenAIProcessorConversationConfig
|
model = OpenAIProcessorConversationConfig
|
||||||
|
|
|
@ -24,7 +24,7 @@ from khoj.utils.constants import default_offline_chat_model
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def loaded_model():
|
def loaded_model():
|
||||||
return download_model(default_offline_chat_model)
|
return download_model(default_offline_chat_model, max_tokens=5000)
|
||||||
|
|
||||||
|
|
||||||
freezegun.configure(extend_ignore_list=["transformers"])
|
freezegun.configure(extend_ignore_list=["transformers"])
|
||||||
|
|
Loading…
Reference in a new issue