Fix, improve args being passed to chat_completion args

- Allow passing completion args through completion_with_backoff
- Pass model_kwargs in a separate arg to simplify this
- Pass model in `model_name' kwarg from the send_message_to_model func
  `model_name' kwarg is used by langchain, not `model' kwarg
This commit is contained in:
Debanjum Singh Solanky 2024-04-26 11:36:31 +05:30
parent d8f2eac6e0
commit 346499f12c
2 changed files with 16 additions and 14 deletions

View file

@ -64,11 +64,12 @@ def extract_questions(
# Get Response from GPT # Get Response from GPT
response = completion_with_backoff( response = completion_with_backoff(
messages=messages, messages=messages,
model_name=model, completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
temperature=temperature, model_kwargs={
max_tokens=max_tokens, "model_name": model,
model_kwargs={"response_format": {"type": "json_object"}}, "openai_api_key": api_key,
openai_api_key=api_key, "model_kwargs": {"response_format": {"type": "json_object"}},
},
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@ -96,9 +97,11 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
# Get Response from GPT # Get Response from GPT
return completion_with_backoff( return completion_with_backoff(
messages=messages, messages=messages,
model=model, model_kwargs={
openai_api_key=api_key, "model_name": model,
model_kwargs={"response_format": {"type": response_type}}, "openai_api_key": api_key,
"model_kwargs": {"response_format": {"type": response_type}},
},
) )

View file

@ -43,13 +43,12 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
def completion_with_backoff(**kwargs) -> str: def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
messages = kwargs.pop("messages") if not "openai_api_key" in model_kwargs:
if not "openai_api_key" in kwargs: model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
aggregated_response = "" aggregated_response = ""
for chunk in llm.stream(messages): for chunk in llm.stream(messages, **completion_kwargs):
aggregated_response += chunk.content aggregated_response += chunk.content
return aggregated_response return aggregated_response