Fix infer_max_tokens func when configured_max_tokens is set to None

Debanjum Singh Solanky 2024-04-23 23:29:15 +05:30
parent 60658a8037
commit 8e77b3dc82


@@ -65,8 +65,9 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
         return None


-def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
+def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
     """Infer max prompt size based on device memory and max context window supported by the model"""
+    configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
     vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
     configured_max_tokens = configured_max_tokens or math.inf  # do not use if set to None
     return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
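
For reference, below is a minimal runnable sketch of the function as it stands after this commit. The get_device_memory stub and its 8 GiB return value are assumptions added here for illustration; the real Khoj helper reads the actual device memory, so the printed numbers are illustrative only.

import math


def get_device_memory() -> int:
    # Hypothetical stub for illustration only; the real Khoj helper
    # returns the available device memory in bytes.
    return 8 * 1024**3  # assume an 8 GiB device


def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
    """Infer max prompt size based on device memory and max context window supported by the model"""
    # New in this commit: map an unset (None) limit to "no limit" explicitly,
    # so the min() below always compares numbers
    configured_max_tokens = math.inf if configured_max_tokens is None else configured_max_tokens
    vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
    configured_max_tokens = configured_max_tokens or math.inf  # do not use if set to None
    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)


# With the assumed 8 GiB of memory, the VRAM heuristic allows ~4294 tokens:
print(infer_max_tokens(2048))                               # 2048: context window is the binding limit
print(infer_max_tokens(32768))                              # 4294: VRAM heuristic is the binding limit
print(infer_max_tokens(32768, configured_max_tokens=1024))  # 1024: explicit limit wins

Note the design choice: the default parameter changes from math.inf to None, and the None case is handled with an explicit `is None` check at the top of the function rather than relying on the truthiness fallback on the later line, which the diff leaves in place as unchanged context.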