From 37ae48d9cf74cbe5894b79e7c01d84fdcb91121c Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Tue, 17 Dec 2024 23:22:42 -0800
Subject: [PATCH] Add support for OpenAI o1 model

Needed to handle the current API limitations of the o1 model,
specifically its inability to stream responses.
---
 .../processor/conversation/openai/utils.py | 22 ++++++++++++-------
 src/khoj/processor/conversation/utils.py   |  2 +-
 src/khoj/utils/constants.py                |  2 +-
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 82f68259..8af836f1 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -58,13 +58,17 @@ def completion_with_backoff(
         openai_clients[client_key] = client
 
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
-    stream = True
 
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
+    stream = True
+    model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1"):
         temperature = 1
-        model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
     if os.getenv("KHOJ_LLM_SEED"):
@@ -74,7 +78,6 @@ def completion_with_backoff(
         messages=formatted_messages,  # type: ignore
         model=model_name,  # type: ignore
         stream=stream,
-        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **model_kwargs,
@@ -165,13 +168,17 @@ def llm_thread(
         client = openai_clients[client_key]
 
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
-    stream = True
 
     # Update request parameters for compatability with o1 model series
     # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    if model_name.startswith("o1"):
+    stream = True
+    model_kwargs["stream_options"] = {"include_usage": True}
+    if model_name == "o1":
+        temperature = 1
+        stream = False
+        model_kwargs.pop("stream_options", None)
+    elif model_name.startswith("o1-"):
         temperature = 1
-        model_kwargs.pop("stop", None)
         model_kwargs.pop("response_format", None)
 
     if os.getenv("KHOJ_LLM_SEED"):
@@ -181,7 +188,6 @@ def llm_thread(
         messages=formatted_messages,
         model=model_name,  # type: ignore
         stream=stream,
-        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **model_kwargs,
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 64a46efc..1374be08 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -54,7 +54,7 @@ model_to_prompt_size = {
     # OpenAI Models
     "gpt-4o": 20000,
     "gpt-4o-mini": 20000,
-    "o1-preview": 20000,
+    "o1": 20000,
     "o1-mini": 20000,
     # Google Models
     "gemini-1.5-flash": 20000,
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index f2ab40c6..59534895 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -38,7 +38,7 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     # OpenAI Pricing: https://openai.com/api/pricing/
     "gpt-4o": {"input": 2.50, "output": 10.00},
     "gpt-4o-mini": {"input": 0.15, "output": 0.60},
-    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1": {"input": 15.0, "output": 60.00},
     "o1-mini": {"input": 3.0, "output": 12.0},
     # Gemini Pricing: https://ai.google.dev/pricing
     "gemini-1.5-flash": {"input": 0.075, "output": 0.30},