mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Fix, improve args being passed to chat_completion args
- Allow passing completion args through completion_with_backoff - Pass model_kwargs in a separate arg to simplify this - Pass model in `model_name' kwarg from the send_message_to_model func `model_name' kwarg is used by langchain, not `model' kwarg
This commit is contained in:
parent
d8f2eac6e0
commit
346499f12c
2 changed files with 16 additions and 14 deletions
|
@ -64,11 +64,12 @@ def extract_questions(
|
|||
# Get Response from GPT
|
||||
response = completion_with_backoff(
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
model_kwargs={"response_format": {"type": "json_object"}},
|
||||
openai_api_key=api_key,
|
||||
completion_kwargs={"temperature": temperature, "max_tokens": max_tokens},
|
||||
model_kwargs={
|
||||
"model_name": model,
|
||||
"openai_api_key": api_key,
|
||||
"model_kwargs": {"response_format": {"type": "json_object"}},
|
||||
},
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -96,9 +97,11 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
|
|||
# Get Response from GPT
|
||||
return completion_with_backoff(
|
||||
messages=messages,
|
||||
model=model,
|
||||
openai_api_key=api_key,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
model_kwargs={
|
||||
"model_name": model,
|
||||
"openai_api_key": api_key,
|
||||
"model_kwargs": {"response_format": {"type": response_type}},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -43,13 +43,12 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(**kwargs) -> str:
|
||||
messages = kwargs.pop("messages")
|
||||
if not "openai_api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
def completion_with_backoff(messages, model_kwargs={}, completion_kwargs={}) -> str:
|
||||
if not "openai_api_key" in model_kwargs:
|
||||
model_kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = ChatOpenAI(**model_kwargs, request_timeout=20, max_retries=1)
|
||||
aggregated_response = ""
|
||||
for chunk in llm.stream(messages):
|
||||
for chunk in llm.stream(messages, **completion_kwargs):
|
||||
aggregated_response += chunk.content
|
||||
return aggregated_response
|
||||
|
||||
|
|
Loading…
Reference in a new issue