Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-12-22 20:28:09 +00:00
Add support for OpenAI o1 model

Needed to handle the current API limitations of the o1 model, specifically its inability to stream responses.
parent b1c5c5bcc9
commit 37ae48d9cf
3 changed files with 16 additions and 10 deletions
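The change applies one pattern throughout: stream the response whenever the model allows it, and fall back to a single blocking request for o1, which cannot stream. A minimal sketch of that dispatch, where `call_openai` and its signature are illustrative rather than khoj's actual helper:

```python
# Illustrative sketch of the stream-with-fallback pattern this commit
# applies; call_openai is a hypothetical helper, not khoj's actual code.
import openai


def call_openai(client: openai.OpenAI, model_name: str, messages: list, temperature: float = 0.2) -> str:
    stream = model_name != "o1"  # the o1 model cannot stream responses
    if model_name.startswith("o1"):
        temperature = 1  # o1 series models only accept the default temperature

    if not stream:
        # Blocking path: one request, one complete response
        response = client.chat.completions.create(
            model=model_name, messages=messages, temperature=temperature, stream=False
        )
        return response.choices[0].message.content or ""

    # Streaming path: accumulate content deltas as they arrive
    text = ""
    for chunk in client.chat.completions.create(
        model=model_name, messages=messages, temperature=temperature, stream=True
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            text += chunk.choices[0].delta.content
    return text
```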
```diff
@@ -58,13 +58,17 @@ def completion_with_backoff(
         openai_clients[client_key] = client

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+    stream = True

     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
-        stream = True
-        model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1"):
+        temperature = 1
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

     if os.getenv("KHOJ_LLM_SEED"):
```
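The `pop` calls matter because, at the time of this commit, the o1 beta rejected unsupported request parameters such as `stop` and `response_format` with a 400 error instead of ignoring them. A hypothetical helper capturing the same stripping logic, not part of the commit itself:

```python
# Hypothetical helper mirroring the parameter stripping in the hunk above.
def strip_unsupported_kwargs(model_name: str, model_kwargs: dict) -> dict:
    if model_name == "o1":
        model_kwargs.pop("stream_options", None)  # o1 cannot stream
    elif model_name.startswith("o1"):
        model_kwargs.pop("stop", None)  # rejected by the o1 beta
        model_kwargs.pop("response_format", None)
    return model_kwargs
```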
```diff
@@ -74,7 +78,6 @@ def completion_with_backoff(
         messages=formatted_messages,  # type: ignore
         model=model_name,  # type: ignore
         stream=stream,
         stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
-        timeout=20,
         **model_kwargs,
```
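`stream_options={"include_usage": True}` asks the API to append one final chunk whose `usage` field reports token counts for the whole request; on every other chunk `usage` is None, and the final chunk carries no choices. A sketch of consuming the stream returned by the call above (variable names are illustrative):

```python
# Sketch: reading both content deltas and the trailing usage chunk from a
# stream created with stream_options={"include_usage": True}.
aggregated, usage = "", None
for chunk in chat:
    if chunk.choices and chunk.choices[0].delta.content:
        aggregated += chunk.choices[0].delta.content
    if chunk.usage:
        usage = chunk.usage  # has prompt_tokens, completion_tokens, total_tokens
```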
```diff
@@ -165,13 +168,17 @@ def llm_thread(
         client = openai_clients[client_key]

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+    stream = True

     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
-        stream = True
-        model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1-"):
+        temperature = 1
         model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)

     if os.getenv("KHOJ_LLM_SEED"):
```
```diff
@@ -181,7 +188,6 @@ def llm_thread(
         messages=formatted_messages,
         model=model_name,  # type: ignore
         stream=stream,
         stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
-        timeout=20,
         **model_kwargs,
```
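Since `llm_thread` forwards the response to the client incrementally, it now has to serve two shapes of result: content deltas while streaming, or one complete message when `stream` is False. A hypothetical generator showing that split (not the actual khoj code):

```python
# Hypothetical generator: yield deltas while streaming, otherwise yield the
# whole message once; `chat` is whatever the create() call above returned.
def iter_response(chat, stream: bool):
    if not stream:
        yield chat.choices[0].message.content or ""
        return
    for chunk in chat:
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
```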
```diff
@@ -54,7 +54,7 @@ model_to_prompt_size = {
     # OpenAI Models
     "gpt-4o": 20000,
     "gpt-4o-mini": 20000,
-    "o1-preview": 20000,
+    "o1": 20000,
     "o1-mini": 20000,
     # Google Models
     "gemini-1.5-flash": 20000,
```
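`model_to_prompt_size` maps each chat model to the prompt-token budget used when packing conversation context. A sketch of how such a map is typically consulted, where `count_tokens` stands in for a real tokenizer call such as tiktoken:

```python
# Sketch: trim history to a model's prompt budget; count_tokens is a
# stand-in for a real tokenizer call (e.g. tiktoken).
def trim_history(history: list, model_name: str, count_tokens) -> list:
    budget = model_to_prompt_size.get(model_name, 20000)
    trimmed, used = [], 0
    for message in reversed(history):  # keep the most recent messages
        tokens = count_tokens(message)
        if used + tokens > budget:
            break
        trimmed.insert(0, message)
        used += tokens
    return trimmed
```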
```diff
@@ -38,7 +38,7 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     # OpenAI Pricing: https://openai.com/api/pricing/
     "gpt-4o": {"input": 2.50, "output": 10.00},
     "gpt-4o-mini": {"input": 0.15, "output": 0.60},
-    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1": {"input": 15.0, "output": 60.00},
     "o1-mini": {"input": 3.0, "output": 12.0},
     # Gemini Pricing: https://ai.google.dev/pricing
     "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
```
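The `model_to_cost` figures are quoted in USD per million tokens, matching OpenAI's published pricing; the new `o1` entry carries the same $15 input / $60 output rates as `o1-preview`. A hypothetical helper turning a usage record into dollars:

```python
# Hypothetical helper: price a request from the table above; values are
# USD per one million tokens.
def estimate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    price = model_to_cost.get(model_name)
    if price is None:
        return 0.0
    return (input_tokens * price["input"] + output_tokens * price["output"]) / 1_000_000
```

For example, a 10,000-token prompt with a 1,000-token answer on `o1` comes to 0.15 + 0.06 = 0.21 USD.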