Add retry logic to OpenAI API queries to increase Chat tenacity

- Move completion and chat_completion into helper methods under utils.py
- Add retry with exponential backoff on OpenAI exceptions using the
  tenacity package. This is officially suggested and is used by other
  popular GPT-based libraries
This commit is contained in:
Debanjum Singh Solanky 2023-03-25 12:36:51 +07:00
parent 0aebf624fc
commit 67c850a4ac
3 changed files with 87 additions and 25 deletions

View file

@ -42,6 +42,7 @@ dependencies = [
"jinja2 == 3.1.2",
"openai >= 0.27.0",
"tiktoken >= 0.3.0",
"tenacity >= 8.2.2",
"pillow == 9.3.0",
"pydantic == 1.9.1",
"pyqt6 == 6.3.1",

View file

@ -1,15 +1,16 @@
# Standard Packages
import os
import json
import logging
from datetime import datetime
# External Packages
import openai
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation.utils import message_to_prompt, generate_chatml_messages_with_context
from khoj.processor.conversation.utils import (
chat_completion_with_backoff,
completion_with_backoff,
message_to_prompt,
generate_chatml_messages_with_context,
)
logger = logging.getLogger(__name__)
@ -19,9 +20,6 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
"""
Answer user query using provided text as reference with OpenAI's GPT
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
# Setup Prompt based on Summary Type
prompt = f"""
You are a friendly, helpful personal assistant.
@ -35,8 +33,13 @@ Question: {user_query}
Answer (in second person):"""
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = openai.Completion.create(
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop='"""'
response = completion_with_backoff(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stop='"""',
api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -48,9 +51,6 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
"""
Summarize user input using OpenAI's GPT
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
# Setup Prompt based on Summary Type
if summary_type == "chat":
prompt = f"""
@ -69,8 +69,14 @@ Summarize the notes in second person perspective:"""
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = openai.Completion.create(
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
response = completion_with_backoff(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop='"""',
api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -82,9 +88,6 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
"""
Infer search queries to retrieve relevant notes to answer user query
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
@ -158,8 +161,13 @@ Q: {text}
"""
# Get Response from GPT
response = openai.Completion.create(
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop=["A: ", "\n"]
response = completion_with_backoff(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stop=["A: ", "\n"],
api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -184,7 +192,6 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
Extract search type from user query using OpenAI's GPT
"""
# Initialize Variables
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
understand_primer = """
Objective: Extract search type from user query and return information as JSON
@ -214,8 +221,14 @@ A:{ "search-type": "notes" }"""
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = openai.Completion.create(
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
response = completion_with_backoff(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"],
api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -229,7 +242,6 @@ def converse(references, user_query, conversation_log={}, api_key=None, temperat
"""
# Initialize Variables
model = "gpt-3.5-turbo"
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
compiled_references = "\n\n".join({f"# {item}" for item in references})
personality_primer = "You are Khoj, a friendly, smart and helpful personal assistant."
@ -252,10 +264,11 @@ Question: {user_query}"""
# Get Response from GPT
logger.debug(f"Conversation Context for GPT: {messages}")
response = openai.ChatCompletion.create(
response = chat_completion_with_backoff(
messages=messages,
model=model,
temperature=temperature,
api_key=api_key,
)
# Extract, Clean Message from GPT's Response

View file

@ -1,16 +1,64 @@
# Standard Packages
import os
import logging
from datetime import datetime
# External Packages
import openai
import tiktoken
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)
# Internal Packages
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
@retry(
    retry=retry_if_exception_type(
        (
            openai.error.Timeout,
            openai.error.APIError,
            openai.error.APIConnectionError,
            openai.error.RateLimitError,
            openai.error.ServiceUnavailableError,
        )
    ),
    wait=wait_random_exponential(min=1, max=30),
    stop=stop_after_attempt(6),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def completion_with_backoff(**kwargs):
    """Query OpenAI's text completion API, retrying transient failures with random exponential backoff."""
    # Prefer an explicitly supplied api_key; otherwise fall back to the environment.
    openai.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
    # Cap each individual request at 60s so a hung connection triggers a retry.
    return openai.Completion.create(**kwargs, request_timeout=60)
@retry(
    retry=retry_if_exception_type(
        (
            openai.error.Timeout,
            openai.error.APIError,
            openai.error.APIConnectionError,
            openai.error.RateLimitError,
            openai.error.ServiceUnavailableError,
        )
    ),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(6),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def chat_completion_with_backoff(**kwargs):
    """Query OpenAI's chat completion API, retrying transient failures with exponential backoff."""
    # Prefer an explicitly supplied api_key; otherwise fall back to the environment.
    openai.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
    # Cap each individual request at 60s so a hung connection triggers a retry.
    return openai.ChatCompletion.create(**kwargs, request_timeout=60)
def generate_chatml_messages_with_context(
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
):