Add support for OpenAI o1 model

Needed to handle the current API limitations of the o1 model,
specifically its inability to stream responses.
Debanjum 2024-12-17 23:22:42 -08:00
parent b1c5c5bcc9
commit 37ae48d9cf
3 changed files with 16 additions and 10 deletions

@@ -58,13 +58,17 @@ def completion_with_backoff(
         openai_clients[client_key] = client

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

-    stream = True
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
+    stream = True
+    model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1"):
         temperature = 1
-        model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

     if os.getenv("KHOJ_LLM_SEED"):
@@ -74,7 +78,6 @@ def completion_with_backoff(
         messages=formatted_messages,  # type: ignore
         model=model_name,  # type: ignore
         stream=stream,
-        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **model_kwargs,
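
For context on the two code paths above, here is a minimal sketch of how the openai Python SDK behaves with and without streaming; the client setup, model names, and messages are placeholders rather than code from this commit:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
messages = [{"role": "user", "content": "Hello"}]

# o1 rejects streaming, so completion_with_backoff now requests one complete response
response = client.chat.completions.create(
    messages=messages,
    model="o1",
    stream=False,
    temperature=1,  # o1 only accepts the default temperature of 1
    timeout=20,
)
print(response.choices[0].message.content)

# Streaming models instead return an iterator of chunks carrying partial deltas
chunks = client.chat.completions.create(
    messages=messages,
    model="gpt-4o-mini",
    stream=True,
)
print("".join(chunk.choices[0].delta.content or "" for chunk in chunks if chunk.choices))
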
@@ -165,13 +168,17 @@ def llm_thread(
         client = openai_clients[client_key]

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

-    stream = True
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
+    stream = True
+    model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1-"):
         temperature = 1
-        model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

     if os.getenv("KHOJ_LLM_SEED"):
@@ -181,7 +188,6 @@ def llm_thread(
         messages=formatted_messages,
         model=model_name,  # type: ignore
         stream=stream,
-        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **model_kwargs,
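
The stream_options juggling above is the subtle part of this change. With {"include_usage": True}, the API appends one final chunk whose choices list is empty and whose usage field carries the token counts; because o1 cannot stream at all, the option has to be popped for it. A minimal sketch of consuming such a stream, assuming the openai Python SDK and a placeholder streaming-capable model:

from openai import OpenAI

client = OpenAI()
chunks = client.chat.completions.create(
    messages=[{"role": "user", "content": "Hi"}],
    model="gpt-4o-mini",  # placeholder; any streaming-capable model works
    stream=True,
    stream_options={"include_usage": True},
)

text, usage = "", None
for chunk in chunks:
    if chunk.choices:  # ordinary content chunks
        text += chunk.choices[0].delta.content or ""
    if chunk.usage is not None:  # final, accounting-only chunk
        usage = chunk.usage
print(text)
print(usage.total_tokens if usage else "no usage reported")
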

@@ -54,7 +54,7 @@ model_to_prompt_size = {
     # OpenAI Models
     "gpt-4o": 20000,
     "gpt-4o-mini": 20000,
-    "o1-preview": 20000,
+    "o1": 20000,
     "o1-mini": 20000,
     # Google Models
     "gemini-1.5-flash": 20000,

@@ -38,7 +38,7 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     # OpenAI Pricing: https://openai.com/api/pricing/
     "gpt-4o": {"input": 2.50, "output": 10.00},
     "gpt-4o-mini": {"input": 0.15, "output": 0.60},
-    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1": {"input": 15.0, "output": 60.00},
    "o1-mini": {"input": 3.0, "output": 12.0},
     # Gemini Pricing: https://ai.google.dev/pricing
     "gemini-1.5-flash": {"input": 0.075, "output": 0.30},