From 346499f12c6fa03e49b09a35fb01c4863513bd42 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 26 Apr 2024 11:36:31 +0530 Subject: [PATCH] Fix, improve args being passed to chat_completion args - Allow passing completion args through completion_with_backoff - Pass model_kwargs in a separate arg to simplify this - Pass model in `model_name' kwarg from the send_message_to_model func `model_name' kwarg is used by langchain, not `model' kwarg --- src/khoj/processor/conversation/openai/gpt.py | 19 +++++++++++-------- .../processor/conversation/openai/utils.py | 11 +++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index e30ad5d5..c4bcd265 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -64,11 +64,12 @@ def extract_questions( # Get Response from GPT response = completion_with_backoff( messages=messages, - model_name=model, - temperature=temperature, - max_tokens=max_tokens, - model_kwargs={"response_format": {"type": "json_object"}}, - openai_api_key=api_key, + completion_kwargs={"temperature": temperature, "max_tokens": max_tokens}, + model_kwargs={ + "model_name": model, + "openai_api_key": api_key, + "model_kwargs": {"response_format": {"type": "json_object"}}, + }, ) # Extract, Clean Message from GPT's Response @@ -96,9 +97,11 @@ def send_message_to_model(messages, api_key, model, response_type="text"): # Get Response from GPT return completion_with_backoff( messages=messages, - model=model, - openai_api_key=api_key, - model_kwargs={"response_format": {"type": response_type}}, + model_kwargs={ + "model_name": model, + "openai_api_key": api_key, + "model_kwargs": {"response_format": {"type": response_type}}, + }, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 844a64b8..908a035d 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -43,13 +43,12 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) -def completion_with_backoff(**kwargs) -> str: - messages = kwargs.pop("messages") - if not "openai_api_key" in kwargs: - kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") - llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) +def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str: + if not "openai_api_key" in model_kwargs: + model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") + llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1) aggregated_response = "" - for chunk in llm.stream(messages): + for chunk in llm.stream(messages, **completion_kwargs): aggregated_response += chunk.content return aggregated_response