Enable free tier users to have unlimited chats with the default chat model (#886)

- Allow free tier users to have unlimited chats with default chat model. It'll only be rate-limited and at the same rate as subscribed users
- In the server chat settings, replace the concept of default/summarizer models with default/advanced chat models. Use the advanced models as a default for subscribed users.
- For each `ChatModelOption' configuration, allow the admin to specify a separate value of `max_tokens' for subscribed users. This allows server admins to configure different max token limits for unsubscribed and subscribed users
- Show error message in web app when hit rate limit or other server errors
This commit is contained in:
sabaimran 2024-08-16 12:14:44 -07:00 committed by GitHub
parent 8dad9362e7
commit c0316a6b5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 210 additions and 92 deletions

View file

@ -222,7 +222,20 @@ export default function Chat() {
try {
await readChatStream(response);
} catch (err) {
console.log(err);
console.error(err);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
// Render error message as current message
const errorMessage = (err as Error).message;
currentMessage.rawResponse = `Encountered Error: ${errorMessage}. Please try again later.`;
// Complete message streaming teardown properly
currentMessage.completed = true;
setMessages([...messages]);
setQueryToProcess("");
setProcessQuerySignal(false);
}
}

View file

@ -386,8 +386,6 @@ export default function ChatMessage(props: ChatMessageProps) {
preElement.prepend(copyButton);
});
console.log("render katex within the chat message");
renderMathInElement(messageRef.current, {
delimiters: [
{ left: "$$", right: "$$", display: true },

View file

@ -672,7 +672,15 @@ export default function SettingsView() {
};
const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") return;
if (!userConfig?.is_active && name !== "search") {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
variant: "destructive",
});
return;
}
try {
const response = await fetch(`/api/model/${name}?id=` + id, {
method: "POST",
@ -1144,7 +1152,7 @@ export default function SettingsView() {
<ChatCircleText className="h-7 w-7 mr-2" />
Chat
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8">
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the chat model to generate text responses
</p>
@ -1169,7 +1177,7 @@ export default function SettingsView() {
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8">
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the search model to find your documents
</p>
@ -1190,7 +1198,7 @@ export default function SettingsView() {
<Palette className="h-7 w-7 mr-2" />
Paint
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8">
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the paint model to generate image responses
</p>
@ -1217,7 +1225,7 @@ export default function SettingsView() {
<Waveform className="h-7 w-7 mr-2" />
Voice
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8">
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the voice model to generate speech
responses

View file

@ -32,10 +32,9 @@ from khoj.database.adapters import (
ClientApplicationAdapters,
ConversationAdapters,
ProcessLockAdapters,
SubscriptionState,
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state,
ais_user_subscribed,
delete_user_requests,
get_all_users,
get_or_create_search_models,
@ -119,15 +118,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
subscribed = await ais_user_subscribed(user)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
@ -144,15 +135,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
subscribed = await ais_user_subscribed(user_with_token.user)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
@ -189,20 +172,10 @@ class UserAuthenticationBackend(AuthenticationBackend):
if user is None:
return AuthCredentials(), UnauthenticatedUser()
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
subscribed = await ais_user_subscribed(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return (
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user, client_application),
)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
# No auth required if server in anonymous mode

View file

@ -300,6 +300,38 @@ async def aget_user_subscription_state(user: KhojUser) -> str:
return subscription_to_state(user_subscription)
async def ais_user_subscribed(user: KhojUser) -> bool:
"""
Get whether the user is subscribed
"""
if not state.billing_enabled or state.anonymous_mode:
return True
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
return subscribed
def is_user_subscribed(user: KhojUser) -> bool:
"""
Get whether the user is subscribed
"""
if not state.billing_enabled or state.anonymous_mode:
return True
subscription_state = get_user_subscription_state(user.email)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
return subscribed
async def get_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst()
@ -751,17 +783,23 @@ class ConversationAdapters:
@staticmethod
def get_conversation_config(user: KhojUser):
subscribed = is_user_subscribed(user)
if not subscribed:
return ConversationAdapters.get_default_conversation_config()
config = UserConversationConfig.objects.filter(user=user).first()
if not config:
return None
return config.setting
if config:
return config.setting
return ConversationAdapters.get_advanced_conversation_config()
@staticmethod
async def aget_conversation_config(user: KhojUser):
subscribed = await ais_user_subscribed(user)
if not subscribed:
return await ConversationAdapters.aget_default_conversation_config()
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
if config:
return config.setting
return ConversationAdapters.aget_advanced_conversation_config()
@staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@ -784,35 +822,38 @@ class ConversationAdapters:
@staticmethod
def get_default_conversation_config():
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.default_model is None:
if server_chat_settings is None or server_chat_settings.chat_default is None:
return ChatModelOptions.objects.filter().first()
return server_chat_settings.default_model
return server_chat_settings.chat_default
@staticmethod
async def aget_default_conversation_config():
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("default_model", "default_model__openai_config")
.prefetch_related("chat_default", "chat_default__openai_config")
.afirst()
)
if server_chat_settings is None or server_chat_settings.default_model is None:
if server_chat_settings is None or server_chat_settings.chat_default is None:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.default_model
return server_chat_settings.chat_default
@staticmethod
async def aget_summarizer_conversation_config():
def get_advanced_conversation_config():
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
return ConversationAdapters.get_default_conversation_config()
return server_chat_settings.chat_advanced
@staticmethod
async def aget_advanced_conversation_config():
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related(
"summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config"
)
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
.afirst()
)
if server_chat_settings is None or (
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
):
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
return server_chat_settings.summarizer_model or server_chat_settings.default_model
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
return await ConversationAdapters.aget_default_conversation_config()
return server_chat_settings.chat_advanced
@staticmethod
def create_conversation_from_public_conversation(

View file

@ -26,6 +26,7 @@ from khoj.database.models import (
SpeechToTextModelOptions,
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
@ -101,6 +102,7 @@ admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
admin.site.register(UserVoiceModelConfig)
admin.site.register(VoiceModelOption)
admin.site.register(UserConversationConfig)
@admin.register(Agent)
@ -191,8 +193,8 @@ class SearchModelConfigAdmin(admin.ModelAdmin):
@admin.register(ServerChatSettings)
class ServerChatSettingsAdmin(admin.ModelAdmin):
list_display = (
"default_model",
"summarizer_model",
"chat_default",
"chat_advanced",
)

View file

@ -0,0 +1,51 @@
# Generated by Django 5.0.7 on 2024-08-16 18:18
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0056_searchmodelconfig_cross_encoder_model_config"),
]
operations = [
migrations.RenameField(
model_name="serverchatsettings",
old_name="default_model",
new_name="chat_default",
),
migrations.RemoveField(
model_name="serverchatsettings",
name="summarizer_model",
),
migrations.AddField(
model_name="chatmodeloptions",
name="subscribed_max_prompt_size",
field=models.IntegerField(blank=True, default=None, null=True),
),
migrations.AddField(
model_name="serverchatsettings",
name="chat_advanced",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_advanced",
to="database.chatmodeloptions",
),
),
migrations.AlterField(
model_name="serverchatsettings",
name="chat_default",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="chat_default",
to="database.chatmodeloptions",
),
),
]

View file

@ -89,6 +89,7 @@ class ChatModelOptions(BaseModel):
ANTHROPIC = "anthropic"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
@ -205,11 +206,11 @@ class GithubRepoConfig(BaseModel):
class ServerChatSettings(BaseModel):
default_model = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model"
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
summarizer_model = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model"
chat_advanced = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
)

View file

@ -53,6 +53,7 @@ async def search_online(
conversation_history: dict,
location: LocationData,
user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
):
@ -91,12 +92,15 @@ async def search_online(
# Read, extract relevant info from the retrieved web pages
if webpages:
webpage_links = [link for link, _, _ in webpages]
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
logger.info(f"Reading web pages at: {list(webpage_links)}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages
@ -132,6 +136,7 @@ async def read_webpages(
conversation_history: dict,
location: LocationData,
user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
):
"Infer web pages to read from the query and extract relevant information from them"
@ -146,7 +151,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url) for url in urls]
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -157,14 +162,14 @@ async def read_webpages(
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None
subquery: str, url: str, content: str = None, subscribed: bool = False
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content)
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")

View file

@ -4,14 +4,14 @@ import logging
import time
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional
from typing import Dict, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
@ -59,7 +59,7 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location
# Initialize Router
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
trial_rate_limit=100, subscribed_rate_limit=100, slug="command"
)
@ -532,10 +532,10 @@ async def chat(
country: Optional[str] = None,
timezone: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=60, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=600, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
):
async def event_generator(q: str):
@ -544,6 +544,7 @@ async def chat(
chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object
subscribed: bool = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗"
q = unquote(q)
@ -632,7 +633,9 @@ async def chat(
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
@ -687,7 +690,7 @@ async def chat(
):
yield result
response = await extract_relevant_summary(q, contextual_data)
response = await extract_relevant_summary(q, contextual_data, subscribed=subscribed)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
@ -792,7 +795,13 @@ async def chat(
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
defiltered_query,
meta_log,
location,
user,
subscribed,
partial(send_event, ChatEvent.STATUS),
custom_filters,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -809,7 +818,7 @@ async def chat(
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
defiltered_query, meta_log, location, user, subscribed, partial(send_event, ChatEvent.STATUS)
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -853,6 +862,7 @@ async def chat(
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
):
if isinstance(result, dict) and ChatEvent.STATUS in result:

View file

@ -252,7 +252,7 @@ async def acreate_title_from_query(query: str) -> str:
return response.strip()
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool):
async def aget_relevant_information_sources(query: str, conversation_history: dict, is_task: bool, subscribed: bool):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
@ -273,7 +273,9 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
)
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")
response = await send_message_to_model_wrapper(
relevant_tools_prompt, response_type="json_object", subscribed=subscribed
)
try:
response = response.strip()
@ -434,7 +436,7 @@ async def schedule_query(q: str, conversation_history: dict) -> Tuple[str, ...]:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@ -447,18 +449,19 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(),
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
chat_model_option=summarizer_model,
chat_model_option=chat_model,
subscribed=subscribed,
)
return response.strip()
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
async def extract_relevant_summary(q: str, corpus: str, subscribed: bool = False) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@ -471,13 +474,14 @@ async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(),
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
chat_model_option=summarizer_model,
chat_model_option=chat_model,
subscribed=subscribed,
)
return response.strip()
@ -489,6 +493,7 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
subscribed: bool = False,
) -> str:
"""
Generate a better image prompt from the given query
@ -533,10 +538,12 @@ async def generate_better_image_prompt(
online_results=simplified_online_results,
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
response = await send_message_to_model_wrapper(
image_prompt, chat_model_option=chat_model, subscribed=subscribed
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -549,13 +556,18 @@ async def send_message_to_model_wrapper(
system_message: str = "",
response_type: str = "text",
chat_model_option: ChatModelOptions = None,
subscribed: bool = False,
):
conversation_config: ChatModelOptions = (
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
)
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
max_tokens = (
conversation_config.subscribed_max_prompt_size
if subscribed and conversation_config.subscribed_max_prompt_size
else conversation_config.max_prompt_size
)
tokenizer = conversation_config.tokenizer
if conversation_config.model_type == "offline":
@ -786,6 +798,7 @@ async def text_to_image(
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
):
status_code = 200
@ -822,6 +835,7 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
)
if send_status_func:
@ -1359,7 +1373,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = ConversationAdapters.get_conversation_config(user)
selected_chat_model_config = (
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models: