Merge pull request #584 from khoj-ai/features/enforce-usage-limits-conversation-type

Add a ConversationCommand rate limiter for the chat endpoint
This commit is contained in:
sabaimran 2023-12-17 11:20:35 +05:30 committed by GitHub
commit fefaa2271d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 74 additions and 28 deletions

View file

@ -77,17 +77,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@ -99,19 +100,18 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
user_with_token.user
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:

View file

@ -401,7 +401,7 @@ class ConversationAdapters:
)
max_results = 3
all_questions = await sync_to_async(list)(all_questions)
all_questions = await sync_to_async(list)(all_questions) # type: ignore
if len(all_questions) < max_results:
return all_questions

View file

@ -642,6 +642,8 @@ To get started, just start typing below. You can also type / to see a list of co
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else if (err.status === 429) {
flashStatusInChatInput("⛔️ " + err.statusText);
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}

View file

@ -46,6 +46,7 @@ from khoj.routers.helpers import (
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
ConversationCommandRateLimiter,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@ -67,6 +68,7 @@ from khoj.utils.state import SearchType
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
# Shared limiter applied in the chat endpoint to restricted conversation commands:
# at most 5 requests per 24-hour window for non-premium users, 100 for premium users.
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
def map_config_to_object(content_source: str):
@ -604,7 +606,13 @@ async def chat_options(
@api.post("/transcribe")
@requires(["authenticated"])
async def transcribe(request: Request, common: CommonQueryParams, file: UploadFile = File(...)):
async def transcribe(
request: Request,
common: CommonQueryParams,
file: UploadFile = File(...),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
):
user: KhojUser = request.user.object
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
user_message: str = None
@ -670,6 +678,8 @@ async def chat(
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log

View file

@ -267,7 +267,7 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
)
image = response.data[0].b64_json
except openai.OpenAIError as e:
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
logger.error(f"Image Generation failed with {e}", exc_info=True)
status_code = 500
return image, status_code
@ -300,6 +300,40 @@ class ApiUserRateLimiter:
user_requests.append(time())
class ConversationCommandRateLimiter:
    """In-memory, per-user rate limiter for restricted conversation commands.

    Tracks when each user invoked each restricted command and rejects further
    invocations once the user has used up their allowance for the trailing
    24-hour window. Users holding the "premium" scope get `subscribed_rate_limit`
    requests per window; all other users get `trial_rate_limit`.
    """

    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
        """Create a limiter.

        Args:
            trial_rate_limit: Requests allowed per command per 24 hours for users without the "premium" scope.
            subscribed_rate_limit: Requests allowed per command per 24 hours for users with the "premium" scope.
        """
        # user uuid -> conversation command -> timestamps of accepted requests, oldest first
        self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
        self.trial_rate_limit = trial_rate_limit
        self.subscribed_rate_limit = subscribed_rate_limit
        # Only these commands are rate limited; every other command passes through untouched
        self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]

    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
        """Record this request against the user's allowance, or reject it if the allowance is used up.

        Does nothing when billing is disabled, when the request is unauthenticated
        or when the command is not one of the restricted commands.

        Args:
            request: Incoming request; its user and auth scopes decide which limit applies.
            conversation_command: The command the user is invoking.

        Raises:
            HTTPException: With status 429 when the user already reached their limit
                for this command within the last 24 hours.
        """
        if state.billing_enabled is False:
            return
        if not request.user.is_authenticated:
            return
        if conversation_command not in self.restricted_commands:
            return

        user: KhojUser = request.user.object
        subscribed = has_required_scope(request, ["premium"])
        request_times = self.cache[user.uuid][conversation_command]

        # Drop requests outside of the 24-hr time window.
        # Timestamps are appended in order, so stale entries form a prefix of the list.
        now = time()
        cutoff = now - 60 * 60 * 24
        stale_count = 0
        for request_time in request_times:
            if request_time >= cutoff:
                break
            stale_count += 1
        del request_times[:stale_count]

        # Check the limit before recording the request. Rejected requests must not be
        # recorded, otherwise a throttled client that keeps retrying extends its own
        # lockout and grows the cache without bound.
        if subscribed and len(request_times) >= self.subscribed_rate_limit:
            raise HTTPException(status_code=429, detail="Too Many Requests")
        if not subscribed and len(request_times) >= self.trial_rate_limit:
            raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")

        request_times.append(now)
        return
class ApiIndexedDataLimiter:
def __init__(
self,
@ -317,7 +351,7 @@ class ApiIndexedDataLimiter:
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0
incoming_data_size_mb = 0.0
deletion_file_names = set()
if not request.user.is_authenticated: