diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index e30ad5d5..c4bcd265 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -64,11 +64,12 @@ def extract_questions(
     # Get Response from GPT
     response = completion_with_backoff(
         messages=messages,
-        model_name=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        model_kwargs={"response_format": {"type": "json_object"}},
-        openai_api_key=api_key,
+        completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
+        model_kwargs={
+            "model_name": model,
+            "openai_api_key": api_key,
+            "model_kwargs": {"response_format": {"type": "json_object"}},
+        },
     )

     # Extract, Clean Message from GPT's Response
@@ -96,9 +97,11 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
-        model=model,
-        openai_api_key=api_key,
-        model_kwargs={"response_format": {"type": response_type}},
+        model_kwargs={
+            "model_name": model,
+            "openai_api_key": api_key,
+            "model_kwargs": {"response_format": {"type": response_type}},
+        },
     )


diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 844a64b8..908a035d 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -43,13 +43,12 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def completion_with_backoff(**kwargs) -> str:
-    messages = kwargs.pop("messages")
-    if not "openai_api_key" in kwargs:
-        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
+    if not "openai_api_key" in model_kwargs:
+        model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
+    llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
     aggregated_response = ""
-    for chunk in llm.stream(messages):
+    for chunk in llm.stream(messages, **completion_kwargs):
         aggregated_response += chunk.content
     return aggregated_response
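
For context, here is a minimal caller sketch of the refactored calling convention. It is not part of the diff; the model name, prompt, and parameter values are illustrative. Per the new `completion_with_backoff` signature, keys in `model_kwargs` go to the `ChatOpenAI` constructor and keys in `completion_kwargs` go to `llm.stream()` at call time.

```python
# Sketch (illustrative, not part of the diff): calling the refactored
# completion_with_backoff from src/khoj/processor/conversation/openai/utils.py.
from langchain.schema import HumanMessage

from khoj.processor.conversation.openai.utils import completion_with_backoff

messages = [HumanMessage(content="List three colors as JSON.")]

response = completion_with_backoff(
    messages=messages,
    # Constructor-time settings, unpacked into ChatOpenAI(**model_kwargs, ...).
    # If "openai_api_key" is omitted, the function falls back to the
    # OPENAI_API_KEY environment variable.
    model_kwargs={
        "model_name": "gpt-4",
        # The nested "model_kwargs" key is ChatOpenAI's own passthrough for
        # raw OpenAI API params such as response_format.
        "model_kwargs": {"response_format": {"type": "json_object"}},
    },
    # Call-time settings, unpacked into llm.stream(messages, **completion_kwargs).
    completion_kwargs={"temperature": 0.2, "max_tokens": 200},
)
print(response)
```

Two caveats worth flagging in review: `model_kwargs={}` and `completion_kwargs={}` are mutable default arguments, and the function mutates `model_kwargs` in place when injecting the environment API key, so the injected key persists across subsequent calls that rely on the default (the usual fix is a `None` default with an `or {}` fallback inside the function). Also, the outer `model_kwargs` parameter and the nested `"model_kwargs"` key it may contain are distinct and easy to confuse.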