Fix and improve args passed to chat_completion

- Allow passing completion args through completion_with_backoff
- Pass model_kwargs as a separate arg to simplify this
- Pass the model via the `model_name` kwarg from the send_message_to_model func,
  since langchain uses the `model_name` kwarg, not `model` (see the sketch below)
Debanjum Singh Solanky 2024-04-26 11:36:31 +05:30
parent d8f2eac6e0
commit 346499f12c
2 changed files with 16 additions and 14 deletions
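
In effect, the change splits one flat kwargs bag into two explicit dicts. A condensed sketch of the before/after call shape (the model name, key, and values here are placeholders; see the diffs below for the real call sites):

# Before: all args flat in **kwargs; `messages` popped back out inside
response = completion_with_backoff(
    messages=messages,
    model_name="gpt-4",                 # placeholder model
    temperature=0.2,
    openai_api_key="<OPENAI_API_KEY>",  # placeholder key
)

# After: model construction args and per-completion args travel separately
response = completion_with_backoff(
    messages=messages,
    completion_kwargs={"temperature": 0.2},
    model_kwargs={"model_name": "gpt-4", "openai_api_key": "<OPENAI_API_KEY>"},
)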


@@ -64,11 +64,12 @@ def extract_questions(
     # Get Response from GPT
     response = completion_with_backoff(
         messages=messages,
-        model_name=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        model_kwargs={"response_format": {"type": "json_object"}},
-        openai_api_key=api_key,
+        completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
+        model_kwargs={
+            "model_name": model,
+            "openai_api_key": api_key,
+            "model_kwargs": {"response_format": {"type": "json_object"}},
+        },
     )
     # Extract, Clean Message from GPT's Response
@@ -96,9 +97,11 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
-        model=model,
-        openai_api_key=api_key,
-        model_kwargs={"response_format": {"type": response_type}},
+        model_kwargs={
+            "model_name": model,
+            "openai_api_key": api_key,
+            "model_kwargs": {"response_format": {"type": response_type}},
+        },
     )
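
For context: the outer model_kwargs dict is unpacked directly into langchain's ChatOpenAI constructor (see the second diff below), and ChatOpenAI has its own model_kwargs field for extra OpenAI API parameters such as response_format, hence the nested key. A minimal sketch of the equivalent construction, assuming langchain's ChatOpenAI:

# Equivalent of `ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)`
# with the dict built in send_message_to_model above:
llm = ChatOpenAI(
    model_name=model,         # langchain's field is `model_name`, not `model`
    openai_api_key=api_key,
    model_kwargs={"response_format": {"type": response_type}},  # forwarded to the OpenAI API
    request_timeout=20,
    max_retries=1,
)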


@@ -43,13 +43,12 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def completion_with_backoff(**kwargs) -> str:
-    messages = kwargs.pop("messages")
-    if not "openai_api_key" in kwargs:
-        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
+    if not "openai_api_key" in model_kwargs:
+        model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
+    llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
     aggregated_response = ""
-    for chunk in llm.stream(messages):
+    for chunk in llm.stream(messages, **completion_kwargs):
         aggregated_response += chunk.content
     return aggregated_response
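
With the new signature, a caller might look roughly like this (a hypothetical usage sketch; the model name, prompt, and values are placeholders, and ChatMessage is langchain's chat message type):

from langchain.schema import ChatMessage

messages = [ChatMessage(role="user", content="Summarize this note.")]  # placeholder prompt
response = completion_with_backoff(
    messages=messages,
    model_kwargs={"model_name": "gpt-4"},  # openai_api_key falls back to $OPENAI_API_KEY
    completion_kwargs={"temperature": 0.2, "max_tokens": 100},
)
print(response)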