Add retry logic to OpenAI API queries to increase Chat tenacity
- Move completion and chat_completion into helper methods under utils.py
- Add retry with exponential backoff on OpenAI exceptions using the tenacity package. This approach is officially suggested by OpenAI and used by other popular GPT-based libraries
This commit is contained in:
parent 0aebf624fc
commit 67c850a4ac

3 changed files with 87 additions and 25 deletions
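
For context, the retry pattern this commit adopts is tenacity's standard decorator API: mark a function with @retry, scope it to specific exception types, and let tenacity re-invoke the function with exponentially growing waits. A minimal sketch, assuming tenacity >= 8.2.2 and the pre-1.0 openai error classes; query_gpt is a hypothetical stand-in, not code from this commit:

    # Sketch: retry with exponential backoff, as recommended for transient OpenAI errors.
    # query_gpt is a hypothetical placeholder for an OpenAI API call.
    import logging

    import openai
    from tenacity import before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

    logger = logging.getLogger(__name__)

    @retry(
        retry=retry_if_exception_type(openai.error.RateLimitError),  # only retry transient errors
        wait=wait_random_exponential(min=1, max=30),  # wait 1-30s, growing exponentially with jitter
        stop=stop_after_attempt(6),  # give up after 6 attempts
        before_sleep=before_sleep_log(logger, logging.DEBUG),  # log each backoff at DEBUG
        reraise=True,  # surface the original exception, not tenacity's RetryError
    )
    def query_gpt(prompt: str):
        return openai.Completion.create(prompt=prompt, model="text-davinci-003", max_tokens=100)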
pyproject.toml

@@ -42,6 +42,7 @@ dependencies = [
     "jinja2 == 3.1.2",
     "openai >= 0.27.0",
     "tiktoken >= 0.3.0",
+    "tenacity >= 8.2.2",
     "pillow == 9.3.0",
     "pydantic == 1.9.1",
     "pyqt6 == 6.3.1",
src/khoj/processor/conversation/gpt.py

@@ -1,15 +1,16 @@
 # Standard Packages
 import os
 import json
 import logging
 from datetime import datetime

 # External Packages
 import openai

 # Internal Packages
 from khoj.utils.constants import empty_escape_sequences
-from khoj.processor.conversation.utils import message_to_prompt, generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    chat_completion_with_backoff,
+    completion_with_backoff,
+    message_to_prompt,
+    generate_chatml_messages_with_context,
+)


 logger = logging.getLogger(__name__)
@@ -19,9 +20,6 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
     """
     Answer user query using provided text as reference with OpenAI's GPT
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Setup Prompt based on Summary Type
     prompt = f"""
 You are a friendly, helpful personal assistant.
@@ -35,8 +33,13 @@ Question: {user_query}
 Answer (in second person):"""
     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop='"""'
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stop='"""',
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -48,9 +51,6 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     """
     Summarize user input using OpenAI's GPT
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Setup Prompt based on Summary Type
     if summary_type == "chat":
         prompt = f"""
@@ -69,8 +69,14 @@ Summarize the notes in second person perspective:"""

     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=0.2,
+        stop='"""',
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -82,9 +88,6 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
-    # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
     # Extract Past User Message and Inferred Questions from Conversation Log
     chat_history = "".join(
         [
@@ -158,8 +161,13 @@ Q: {text}
 """

     # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop=["A: ", "\n"]
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stop=["A: ", "\n"],
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -184,7 +192,6 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
     Extract search type from user query using OpenAI's GPT
     """
     # Initialize Variables
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
     understand_primer = """
 Objective: Extract search type from user query and return information as JSON

@@ -214,8 +221,14 @@ A:{ "search-type": "notes" }"""

     # Get Response from GPT
     logger.debug(f"Prompt for GPT: {prompt}")
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
+    response = completion_with_backoff(
+        prompt=prompt,
+        model=model,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        frequency_penalty=0.2,
+        stop=["\n"],
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
@@ -229,7 +242,6 @@ def converse(references, user_query, conversation_log={}, api_key=None, temperat
     """
     # Initialize Variables
     model = "gpt-3.5-turbo"
-    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
     compiled_references = "\n\n".join({f"# {item}" for item in references})

     personality_primer = "You are Khoj, a friendly, smart and helpful personal assistant."
@@ -252,10 +264,11 @@ Question: {user_query}"""

     # Get Response from GPT
     logger.debug(f"Conversation Context for GPT: {messages}")
-    response = openai.ChatCompletion.create(
+    response = chat_completion_with_backoff(
         messages=messages,
         model=model,
         temperature=temperature,
+        api_key=api_key,
     )

     # Extract, Clean Message from GPT's Response
src/khoj/processor/conversation/utils.py

@@ -1,16 +1,64 @@
 # Standard Packages
+import os
 import logging
 from datetime import datetime

 # External Packages
+import openai
 import tiktoken
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)

 # Internal Packages
 from khoj.utils.helpers import merge_dicts


 logger = logging.getLogger(__name__)
 max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}


+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_random_exponential(min=1, max=30),
+    stop=stop_after_attempt(6),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def completion_with_backoff(**kwargs):
+    openai.api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    return openai.Completion.create(**kwargs, request_timeout=60)
+
+
+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(6),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def chat_completion_with_backoff(**kwargs):
+    openai.api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+    return openai.ChatCompletion.create(**kwargs, request_timeout=60)
+
+
 def generate_chatml_messages_with_context(
     user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
 ):
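
With these helpers, call sites no longer set openai.api_key themselves: the key is resolved per call (explicit argument or the OPENAI_API_KEY environment variable) and retries happen transparently behind the function call. A hypothetical usage sketch, not part of the commit:

    # Callers pass api_key through; backoff and timeouts are handled inside the helper.
    from khoj.processor.conversation.utils import completion_with_backoff

    response = completion_with_backoff(
        prompt="Q: What is Khoj?\nA:",
        model="text-davinci-003",
        temperature=0,
        max_tokens=50,
        api_key=None,  # falls back to the OPENAI_API_KEY environment variable
    )
    print(response["choices"][0]["text"])

Note that the two helpers use different wait strategies: completion_with_backoff backs off with randomized exponential waits of 1-30 seconds, while chat_completion_with_backoff uses wait_exponential(multiplier=1, min=4, max=10), keeping chat retries on a tighter 4-10 second schedule.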