Mirror of https://github.com/khoj-ai/khoj.git
Add retry logic to OpenAI API queries to increase Chat tenacity

- Move completion and chat_completion into helper methods under utils.py
- Add retry with exponential backoff on OpenAI exceptions using the tenacity package. This is officially suggested and used by other popular GPT-based libraries.
parent 0aebf624fc
commit 67c850a4ac

3 changed files with 87 additions and 25 deletions
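
The retry pattern adopted by this commit is tenacity's decorator-based exponential backoff. As a minimal, self-contained sketch of that pattern (the flaky_call function and TransientAPIError below are illustrative stand-ins, not part of this commit):

import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger(__name__)


class TransientAPIError(Exception):
    """Illustrative stand-in for the transient OpenAI errors retried in this commit."""


@retry(
    retry=retry_if_exception_type(TransientAPIError),
    wait=wait_random_exponential(min=1, max=30),  # jittered exponential backoff
    stop=stop_after_attempt(6),  # give up after 6 attempts
    before_sleep=before_sleep_log(logger, logging.DEBUG),  # log each backoff
    reraise=True,  # re-raise the original exception, not tenacity's RetryError
)
def flaky_call(attempts=[0]):
    # Fails twice, then succeeds, to demonstrate the retry behaviour.
    attempts[0] += 1
    if attempts[0] < 3:
        raise TransientAPIError("temporary failure")
    return f"succeeded on attempt {attempts[0]}"


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    print(flaky_call())  # prints "succeeded on attempt 3" after two retries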
pyproject.toml

@@ -42,6 +42,7 @@ dependencies = [
     "jinja2 == 3.1.2",
     "openai >= 0.27.0",
     "tiktoken >= 0.3.0",
+    "tenacity >= 8.2.2",
     "pillow == 9.3.0",
     "pydantic == 1.9.1",
     "pyqt6 == 6.3.1",
src/khoj/processor/conversation/gpt.py

@@ -1,15 +1,16 @@
 # Standard Packages
-import os
 import json
 import logging
 from datetime import datetime

-# External Packages
-import openai
-
 # Internal Packages
 from khoj.utils.constants import empty_escape_sequences
-from khoj.processor.conversation.utils import message_to_prompt, generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    chat_completion_with_backoff,
+    completion_with_backoff,
+    message_to_prompt,
+    generate_chatml_messages_with_context,
+)


 logger = logging.getLogger(__name__)
@@ -19,9 +20,6 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
     """
     Answer user query using provided text as reference with OpenAI's GPT
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Setup Prompt based on Summary Type
     prompt = f"""
 You are a friendly, helpful personal assistant.
@@ -35,8 +33,13 @@ Question: {user_query}
 Answer (in second person):"""
     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop='"""'
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stop='"""',
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -48,9 +51,6 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     """
     Summarize user input using OpenAI's GPT
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Setup Prompt based on Summary Type
     if summary_type == "chat":
         prompt = f"""
@@ -69,8 +69,14 @@ Summarize the notes in second person perspective:"""

     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=0.2,
+        stop='"""',
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -82,9 +88,6 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Extract Past User Message and Inferred Questions from Conversation Log
     chat_history = "".join(
         [
@@ -158,8 +161,13 @@ Q: {text}
 """

     # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop=["A: ", "\n"]
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stop=["A: ", "\n"],
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -184,7 +192,6 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     Extract search type from user query using OpenAI's GPT
     """
     # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
     understand_primer = """
 Objective: Extract search type from user query and return information as JSON

@@ -214,8 +221,14 @@ A:{ "search-type": "notes" }"""

     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=0.2,
+        stop=["\n"],
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -229,7 +242,6 @@ def converse(references, user_query, conversation_log={}, api_key=None, temperat
     """
     # Initialize Variables
     model = "gpt-3.5-turbo"
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
     compiled_references = "\n\n".join({f"# {item}" for item in references})

     personality_primer = "You are Khoj, a friendly, smart and helpful personal assistant."
@@ -252,10 +264,11 @@ Question: {user_query}"""

     # Get Response from GPT
     logger.debug(f"Conversation Context for GPT: {messages}")
-    response = openai.ChatCompletion.create(
+    response = chat_completion_with_backoff(
         messages=messages,
         model=model,
         temperature=temperature,
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
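With the call sites above rewritten, retries happen transparently inside the helpers. A hypothetical invocation of the refactored helper (the prompt and parameter values below are placeholders, not from this commit):

from khoj.processor.conversation.utils import completion_with_backoff

# Retries transparently on transient OpenAI errors, then returns the
# regular openai.Completion response object.
response = completion_with_backoff(
    prompt="Answer briefly: what is Khoj?",  # placeholder prompt
    model="text-davinci-003",
    temperature=0.5,
    max_tokens=50,
    stop='"""',
    api_key=None,  # falls back to the OPENAI_API_KEY environment variable
)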
src/khoj/processor/conversation/utils.py

@@ -1,16 +1,64 @@
 # Standard Packages
+import os
+import logging
 from datetime import datetime

 # External Packages
+import openai
 import tiktoken
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)

 # Internal Packages
 from khoj.utils.helpers import merge_dicts


+logger = logging.getLogger(__name__)
 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}


+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_random_exponential(min=1, max=30),
+    stop=stop_after_attempt(6),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def completion_with_backoff(**kwargs):
+    openai.api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    return openai.Completion.create(**kwargs, request_timeout=60)
+
+
+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(6),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def chat_completion_with_backoff(**kwargs):
+    openai.api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    return openai.ChatCompletion.create(**kwargs, request_timeout=60)
+
+
 def generate_chatml_messages_with_context(
     user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
 ):
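Note that the two helpers use different backoff curves: completion_with_backoff waits a jittered wait_random_exponential(min=1, max=30), while chat_completion_with_backoff uses the deterministic wait_exponential(multiplier=1, min=4, max=10). Both stop after six attempts and, because reraise=True, surface the original OpenAI exception to the caller rather than tenacity's RetryError; request_timeout=60 additionally bounds each individual HTTP attempt. A quick illustration of the deterministic chat schedule, assuming tenacity's documented formula (wait = multiplier * 2^attempt, clamped to [min, max]):

# Illustrative only: mirrors tenacity's documented wait_exponential clamping
# to show the pauses between the six chat completion attempts.
multiplier, lo, hi = 1, 4, 10
waits = [max(lo, min(multiplier * 2**attempt, hi)) for attempt in range(1, 6)]
print(waits)  # -> [4, 4, 8, 10, 10] seconds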