Mirror of https://github.com/khoj-ai/khoj.git, synced 2025-02-17 08:04:21 +00:00
Enable free tier users to have unlimited chats with the default chat model (#886)
- Allow free tier users to have unlimited chats with the default chat model. They are only rate-limited, and at the same rate as subscribed users
- In the server chat settings, replace the concept of default/summarizer models with default/advanced chat models. Use the advanced model as the default for subscribed users
- For each `ChatModelOptions` configuration, allow the admin to specify a separate `max_tokens` value for subscribed users. This lets server admins configure different max token limits for unsubscribed and subscribed users
- Show an error message in the web app when a rate limit or other server error is hit
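As a rough illustration of the per-tier token limit, below is a minimal Python sketch of the selection logic that `send_message_to_model_wrapper` applies further down in this diff. The `ChatModelConfig` dataclass, the `resolve_max_tokens` helper, and the numeric limits are illustrative stand-ins, not actual khoj code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatModelConfig:
    # Illustrative stand-in for the ChatModelOptions fields touched by this commit
    chat_model: str
    max_prompt_size: Optional[int] = None
    subscribed_max_prompt_size: Optional[int] = None


def resolve_max_tokens(config: ChatModelConfig, subscribed: bool) -> Optional[int]:
    # Subscribed users get the larger subscribed_max_prompt_size when the admin has set one;
    # free tier users, and models without a subscriber limit, fall back to max_prompt_size.
    if subscribed and config.subscribed_max_prompt_size:
        return config.subscribed_max_prompt_size
    return config.max_prompt_size


# Example with hypothetical limits: free tier capped at 4096 prompt tokens, subscribers at 16384
config = ChatModelConfig("example-chat-model", max_prompt_size=4096, subscribed_max_prompt_size=16384)
assert resolve_max_tokens(config, subscribed=False) == 4096
assert resolve_max_tokens(config, subscribed=True) == 16384

Falling back to `max_prompt_size` when no subscriber-specific limit is configured keeps behaviour unchanged for existing model configurations.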
This commit is contained in: parent 8dad9362e7, commit c0316a6b5d
11 changed files with 210 additions and 92 deletions
@@ -222,7 +222,20 @@ export default function Chat() {
        try {
            await readChatStream(response);
        } catch (err) {
            console.log(err);
            console.error(err);
            // Retrieve latest message being processed
            const currentMessage = messages.find((message) => !message.completed);
            if (!currentMessage) return;

            // Render error message as current message
            const errorMessage = (err as Error).message;
            currentMessage.rawResponse = `Encountered Error: ${errorMessage}. Please try again later.`;

            // Complete message streaming teardown properly
            currentMessage.completed = true;
            setMessages([...messages]);
            setQueryToProcess("");
            setProcessQuerySignal(false);
        }
    }
@@ -386,8 +386,6 @@ export default function ChatMessage(props: ChatMessageProps) {
                preElement.prepend(copyButton);
            });

            console.log("render katex within the chat message");

            renderMathInElement(messageRef.current, {
                delimiters: [
                    { left: "$$", right: "$$", display: true },
@@ -672,7 +672,15 @@ export default function SettingsView() {
    };

    const updateModel = (name: string) => async (id: string) => {
        if (!userConfig?.is_active && name !== "search") return;
        if (!userConfig?.is_active && name !== "search") {
            toast({
                title: `Model Update`,
                description: `You need to be subscribed to update ${name} models`,
                variant: "destructive",
            });
            return;
        }

        try {
            const response = await fetch(`/api/model/${name}?id=` + id, {
                method: "POST",

@@ -1144,7 +1152,7 @@ export default function SettingsView() {
                        <ChatCircleText className="h-7 w-7 mr-2" />
                        Chat
                    </CardHeader>
                    <CardContent className="overflow-hidden pb-12 grid gap-8">
                    <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
                        <p className="text-gray-400">
                            Pick the chat model to generate text responses
                        </p>

@@ -1169,7 +1177,7 @@ export default function SettingsView() {
                        <FileMagnifyingGlass className="h-7 w-7 mr-2" />
                        Search
                    </CardHeader>
                    <CardContent className="overflow-hidden pb-12 grid gap-8">
                    <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
                        <p className="text-gray-400">
                            Pick the search model to find your documents
                        </p>

@@ -1190,7 +1198,7 @@ export default function SettingsView() {
                        <Palette className="h-7 w-7 mr-2" />
                        Paint
                    </CardHeader>
                    <CardContent className="overflow-hidden pb-12 grid gap-8">
                    <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
                        <p className="text-gray-400">
                            Pick the paint model to generate image responses
                        </p>

@@ -1217,7 +1225,7 @@ export default function SettingsView() {
                        <Waveform className="h-7 w-7 mr-2" />
                        Voice
                    </CardHeader>
                    <CardContent className="overflow-hidden pb-12 grid gap-8">
                    <CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
                        <p className="text-gray-400">
                            Pick the voice model to generate speech
                            responses
@@ -32,10 +32,9 @@ from khoj.database.adapters import (
    ClientApplicationAdapters,
    ConversationAdapters,
    ProcessLockAdapters,
    SubscriptionState,
    aget_or_create_user_by_phone_number,
    aget_user_by_phone_number,
    aget_user_subscription_state,
    ais_user_subscribed,
    delete_user_requests,
    get_all_users,
    get_or_create_search_models,

@@ -119,15 +118,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
                .afirst()
            )
            if user:
                if not state.billing_enabled:
                    return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)

                subscription_state = await aget_user_subscription_state(user)
                subscribed = (
                    subscription_state == SubscriptionState.SUBSCRIBED.value
                    or subscription_state == SubscriptionState.TRIAL.value
                    or subscription_state == SubscriptionState.UNSUBSCRIBED.value
                )
                subscribed = await ais_user_subscribed(user)
                if subscribed:
                    return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
                return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)

@@ -144,15 +135,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
                .afirst()
            )
            if user_with_token:
                if not state.billing_enabled:
                    return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)

                subscription_state = await aget_user_subscription_state(user_with_token.user)
                subscribed = (
                    subscription_state == SubscriptionState.SUBSCRIBED.value
                    or subscription_state == SubscriptionState.TRIAL.value
                    or subscription_state == SubscriptionState.UNSUBSCRIBED.value
                )
                subscribed = await ais_user_subscribed(user_with_token.user)
                if subscribed:
                    return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
                return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)

@@ -189,20 +172,10 @@ class UserAuthenticationBackend(AuthenticationBackend):
            if user is None:
                return AuthCredentials(), UnauthenticatedUser()

            if not state.billing_enabled:
                return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
            subscribed = await ais_user_subscribed(user)

            subscription_state = await aget_user_subscription_state(user)
            subscribed = (
                subscription_state == SubscriptionState.SUBSCRIBED.value
                or subscription_state == SubscriptionState.TRIAL.value
                or subscription_state == SubscriptionState.UNSUBSCRIBED.value
            )
            if subscribed:
                return (
                    AuthCredentials(["authenticated", "premium"]),
                    AuthenticatedKhojUser(user, client_application),
                )
                return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
            return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)

        # No auth required if server in anonymous mode
@@ -300,6 +300,38 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
    return subscription_to_state(user_subscription)


async def ais_user_subscribed(user: KhojUser) -> bool:
    """
    Get whether the user is subscribed
    """
    if not state.billing_enabled or state.anonymous_mode:
        return True

    subscription_state = await aget_user_subscription_state(user)
    subscribed = (
        subscription_state == SubscriptionState.SUBSCRIBED.value
        or subscription_state == SubscriptionState.TRIAL.value
        or subscription_state == SubscriptionState.UNSUBSCRIBED.value
    )
    return subscribed


def is_user_subscribed(user: KhojUser) -> bool:
    """
    Get whether the user is subscribed
    """
    if not state.billing_enabled or state.anonymous_mode:
        return True

    subscription_state = get_user_subscription_state(user.email)
    subscribed = (
        subscription_state == SubscriptionState.SUBSCRIBED.value
        or subscription_state == SubscriptionState.TRIAL.value
        or subscription_state == SubscriptionState.UNSUBSCRIBED.value
    )
    return subscribed


async def get_user_by_email(email: str) -> KhojUser:
    return await KhojUser.objects.filter(email=email).afirst()


@@ -751,17 +783,23 @@ class ConversationAdapters:

    @staticmethod
    def get_conversation_config(user: KhojUser):
        subscribed = is_user_subscribed(user)
        if not subscribed:
            return ConversationAdapters.get_default_conversation_config()
        config = UserConversationConfig.objects.filter(user=user).first()
        if not config:
            return None
        return config.setting
        if config:
            return config.setting
        return ConversationAdapters.get_advanced_conversation_config()

    @staticmethod
    async def aget_conversation_config(user: KhojUser):
        subscribed = await ais_user_subscribed(user)
        if not subscribed:
            return await ConversationAdapters.aget_default_conversation_config()
        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
        if not config:
            return None
        return config.setting
        if config:
            return config.setting
        return ConversationAdapters.aget_advanced_conversation_config()

    @staticmethod
    async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:

@@ -784,35 +822,38 @@ class ConversationAdapters:
    @staticmethod
    def get_default_conversation_config():
        server_chat_settings = ServerChatSettings.objects.first()
        if server_chat_settings is None or server_chat_settings.default_model is None:
        if server_chat_settings is None or server_chat_settings.chat_default is None:
            return ChatModelOptions.objects.filter().first()
        return server_chat_settings.default_model
        return server_chat_settings.chat_default

    @staticmethod
    async def aget_default_conversation_config():
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related("default_model", "default_model__openai_config")
            .prefetch_related("chat_default", "chat_default__openai_config")
            .afirst()
        )
        if server_chat_settings is None or server_chat_settings.default_model is None:
        if server_chat_settings is None or server_chat_settings.chat_default is None:
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.default_model
        return server_chat_settings.chat_default

    @staticmethod
    async def aget_summarizer_conversation_config():
    def get_advanced_conversation_config():
        server_chat_settings = ServerChatSettings.objects.first()
        if server_chat_settings is None or server_chat_settings.chat_advanced is None:
            return ConversationAdapters.get_default_conversation_config()
        return server_chat_settings.chat_advanced

    @staticmethod
    async def aget_advanced_conversation_config():
        server_chat_settings: ServerChatSettings = (
            await ServerChatSettings.objects.filter()
            .prefetch_related(
                "summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
            )
            .prefetch_related("chat_advanced", "chat_advanced__openai_config")
            .afirst()
        )
        if server_chat_settings is None or (
            server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
        ):
            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
        return server_chat_settings.summarizer_model or server_chat_settings.default_model
        if server_chat_settings is None or server_chat_settings.chat_advanced is None:
            return await ConversationAdapters.aget_default_conversation_config()
        return server_chat_settings.chat_advanced

    @staticmethod
    def create_conversation_from_public_conversation(
@@ -26,6 +26,7 @@ from khoj.database.models import (
    SpeechToTextModelOptions,
    Subscription,
    TextToImageModelConfig,
    UserConversationConfig,
    UserSearchModelConfig,
    UserVoiceModelConfig,
    VoiceModelOption,

@@ -101,6 +102,7 @@ admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
admin.site.register(UserVoiceModelConfig)
admin.site.register(VoiceModelOption)
admin.site.register(UserConversationConfig)


@admin.register(Agent)

@@ -191,8 +193,8 @@ class SearchModelConfigAdmin(admin.ModelAdmin):
@admin.register(ServerChatSettings)
class ServerChatSettingsAdmin(admin.ModelAdmin):
    list_display = (
        "default_model",
        "summarizer_model",
        "chat_default",
        "chat_advanced",
    )
@@ -0,0 +1,51 @@
# Generated by Django 5.0.7 on 2024-08-16 18:18

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0056_searchmodelconfig_cross_encoder_model_config"),
    ]

    operations = [
        migrations.RenameField(
            model_name="serverchatsettings",
            old_name="default_model",
            new_name="chat_default",
        ),
        migrations.RemoveField(
            model_name="serverchatsettings",
            name="summarizer_model",
        ),
        migrations.AddField(
            model_name="chatmodeloptions",
            name="subscribed_max_prompt_size",
            field=models.IntegerField(blank=True, default=None, null=True),
        ),
        migrations.AddField(
            model_name="serverchatsettings",
            name="chat_advanced",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="chat_advanced",
                to="database.chatmodeloptions",
            ),
        ),
        migrations.AlterField(
            model_name="serverchatsettings",
            name="chat_default",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="chat_default",
                to="database.chatmodeloptions",
            ),
        ),
    ]
@@ -89,6 +89,7 @@ class ChatModelOptions(BaseModel):
        ANTHROPIC = "anthropic"

    max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
    subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
    tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
    chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)

@@ -205,11 +206,11 @@ class GithubRepoConfig(BaseModel):


class ServerChatSettings(BaseModel):
    default_model = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model"
    chat_default = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
    )
    summarizer_model = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model"
    chat_advanced = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
    )
@@ -53,6 +53,7 @@ async def search_online(
    conversation_history: dict,
    location: LocationData,
    user: KhojUser,
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    custom_filters: List[str] = [],
):

@@ -91,12 +92,15 @@ async def search_online(
    # Read, extract relevant info from the retrieved web pages
    if webpages:
        webpage_links = [link for link, _, _ in webpages]
        logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
        logger.info(f"Reading web pages at: {list(webpage_links)}")
        if send_status_func:
            webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
            async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
                yield {ChatEvent.STATUS: event}
        tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
        tasks = [
            read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
            for link, subquery, content in webpages
        ]
        results = await asyncio.gather(*tasks)

    # Collect extracted info from the retrieved web pages

@@ -132,6 +136,7 @@ async def read_webpages(
    conversation_history: dict,
    location: LocationData,
    user: KhojUser,
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
):
    "Infer web pages to read from the query and extract relevant information from them"

@@ -146,7 +151,7 @@ async def read_webpages(
        webpage_links_str = "\n- " + "\n- ".join(list(urls))
        async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
            yield {ChatEvent.STATUS: event}
    tasks = [read_webpage_and_extract_content(query, url) for url in urls]
    tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
    results = await asyncio.gather(*tasks)

    response: Dict[str, Dict] = defaultdict(dict)

@@ -157,14 +162,14 @@ async def read_webpages(


async def read_webpage_and_extract_content(
    subquery: str, url: str, content: str = None
    subquery: str, url: str, content: str = None, subscribed: bool = False
) -> Tuple[str, Union[None, str], str]:
    try:
        if is_none_or_empty(content):
            with timer(f"Reading web page at '{url}' took", logger):
                content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
        with timer(f"Extracting relevant information from web page at '{url}' took", logger):
            extracted_info = await extract_relevant_info(subquery, content)
            extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
        return subquery, extracted_info, url
    except Exception as e:
        logger.error(f"Failed to read web page at '{url}' with {e}")
@@ -4,14 +4,14 @@ import logging
import time
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
from typing import Dict, Optional
from urllib.parse import unquote

from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from starlette.authentication import has_required_scope, requires

from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (

@@ -59,7 +59,7 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location
# Initialize Router
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
    trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
    trial_rate_limit=100, subscribed_rate_limit=100, slug="command"
)


@@ -532,10 +532,10 @@ async def chat(
    country: Optional[str] = None,
    timezone: Optional[str] = None,
    rate_limiter_per_minute=Depends(
        ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
        ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
    ),
    rate_limiter_per_day=Depends(
        ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
        ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
    ),
):
    async def event_generator(q: str):

@@ -544,6 +544,7 @@ async def chat(
        chat_metadata: dict = {}
        connection_alive = True
        user: KhojUser = request.user.object
        subscribed: bool = has_required_scope(request, ["premium"])
        event_delimiter = "␃🔚␗"
        q = unquote(q)

@@ -632,7 +633,9 @@ async def chat(
        is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]

        if conversation_commands == [ConversationCommand.Default] or is_automated_task:
            conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
            conversation_commands = await aget_relevant_information_sources(
                q, meta_log, is_automated_task, subscribed=subscribed
            )
            conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
            async for result in send_event(
                ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"

@@ -687,7 +690,7 @@ async def chat(
            ):
                yield result

            response = await extract_relevant_summary(q, contextual_data)
            response = await extract_relevant_summary(q, contextual_data, subscribed=subscribed)
            response_log = str(response)
            async for result in send_llm_response(response_log):
                yield result

@@ -792,7 +795,13 @@ async def chat(
        if ConversationCommand.Online in conversation_commands:
            try:
                async for result in search_online(
                    defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
                    defiltered_query,
                    meta_log,
                    location,
                    user,
                    subscribed,
                    partial(send_event, ChatEvent.STATUS),
                    custom_filters,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]

@@ -809,7 +818,7 @@ async def chat(
        if ConversationCommand.Webpage in conversation_commands:
            try:
                async for result in read_webpages(
                    defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
                    defiltered_query, meta_log, location, user, subscribed, partial(send_event, ChatEvent.STATUS)
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]

@@ -853,6 +862,7 @@ async def chat(
                location_data=location,
                references=compiled_references,
                online_results=online_results,
                subscribed=subscribed,
                send_status_func=partial(send_event, ChatEvent.STATUS),
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -252,7 +252,7 @@ async def acreate_title_from_query(query: str) -> str:
    return response.strip()


async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool):
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool, subscribed: bool):
    """
    Given a query, determine which of the available tools the agent should use in order to answer appropriately.
    """

@@ -273,7 +273,9 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")
        response = await send_message_to_model_wrapper(
            relevant_tools_prompt, response_type="json_object", subscribed=subscribed
        )

    try:
        response = response.strip()

@@ -434,7 +436,7 @@ async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")


async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus
    """

@@ -447,18 +449,19 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
        corpus=corpus.strip(),
    )

    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
    chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_information,
            chat_model_option=summarizer_model,
            chat_model_option=chat_model,
            subscribed=subscribed,
        )
    return response.strip()


async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
async def extract_relevant_summary(q: str, corpus: str, subscribed: bool = False) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus
    """

@@ -471,13 +474,14 @@ async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
        corpus=corpus.strip(),
    )

    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
    chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_summary,
            chat_model_option=summarizer_model,
            chat_model_option=chat_model,
            subscribed=subscribed,
        )
    return response.strip()


@@ -489,6 +493,7 @@ async def generate_better_image_prompt(
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    model_type: Optional[str] = None,
    subscribed: bool = False,
) -> str:
    """
    Generate a better image prompt from the given query

@@ -533,10 +538,12 @@ async def generate_better_image_prompt(
        online_results=simplified_online_results,
    )

    summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
    chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

    with timer("Chat actor: Generate contextual image prompt", logger):
        response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
        response = await send_message_to_model_wrapper(
            image_prompt, chat_model_option=chat_model, subscribed=subscribed
        )
        response = response.strip()
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

@@ -549,13 +556,18 @@ async def send_message_to_model_wrapper(
    system_message: str = "",
    response_type: str = "text",
    chat_model_option: ChatModelOptions = None,
    subscribed: bool = False,
):
    conversation_config: ChatModelOptions = (
        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
    )

    chat_model = conversation_config.chat_model
    max_tokens = conversation_config.max_prompt_size
    max_tokens = (
        conversation_config.subscribed_max_prompt_size
        if subscribed and conversation_config.subscribed_max_prompt_size
        else conversation_config.max_prompt_size
    )
    tokenizer = conversation_config.tokenizer

    if conversation_config.model_type == "offline":

@@ -786,6 +798,7 @@ async def text_to_image(
    location_data: LocationData,
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
):
    status_code = 200

@@ -822,6 +835,7 @@ async def text_to_image(
        note_references=references,
        online_results=online_results,
        model_type=text_to_image_config.model_type,
        subscribed=subscribed,
    )

    if send_status_func:

@@ -1359,7 +1373,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
    current_notion_config = get_user_notion_config(user)
    notion_token = current_notion_config.token if current_notion_config else ""

    selected_chat_model_config = ConversationAdapters.get_conversation_config(user)
    selected_chat_model_config = (
        ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
    )
    chat_models = ConversationAdapters.get_conversation_processor_options().all()
    chat_model_options = list()
    for chat_model in chat_models: