Simplify argument names used in khoj openai completion functions

- Match argument names passed to khoj openai completion funcs with
  arguments passed to langchain calls to OpenAI
- This simplifies the logic in the khoj openai completion funcs
This commit is contained in:
Debanjum Singh Solanky 2023-05-31 10:59:31 +05:30
parent 703a7c89c0
commit ed4d0f9076
2 changed files with 11 additions and 11 deletions

View file

@@ -27,7 +27,7 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
logger.debug(f"Prompt for GPT: {prompt}")
response = completion_with_backoff(
prompt=prompt,
-model=model,
+model_name=model,
temperature=temperature,
max_tokens=max_tokens,
stop='"""',
@@ -52,7 +52,7 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
logger.debug(f"Prompt for GPT: {prompt}")
response = completion_with_backoff(
prompt=prompt,
-model=model,
+model_name=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
@@ -96,7 +96,7 @@ def extract_questions(text, model="text-davinci-003", conversation_log={}, api_k
# Get Response from GPT
response = completion_with_backoff(
prompt=prompt,
-model=model,
+model_name=model,
temperature=temperature,
max_tokens=max_tokens,
stop=["A: ", "\n"],
@@ -132,7 +132,7 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
logger.debug(f"Prompt for GPT: {prompt}")
response = completion_with_backoff(
prompt=prompt,
-model=model,
+model_name=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
@@ -174,9 +174,9 @@ def converse(references, user_query, conversation_log={}, model="gpt-3.5-turbo",
logger.debug(f"Conversation Context for GPT: {messages}")
response = chat_completion_with_backoff(
messages=messages,
-model=model,
+model_name=model,
temperature=temperature,
-api_key=api_key,
+openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response

View file

@@ -41,7 +41,8 @@ max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
)
def completion_with_backoff(**kwargs):
prompt = kwargs.pop("prompt")
-kwargs["openai_api_key"] = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+if "openai_api_key" not in kwargs:
+    kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
llm = OpenAI(**kwargs, request_timeout=10, max_retries=1)
return llm(prompt)
@@ -59,12 +60,11 @@ def completion_with_backoff(**kwargs):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
-def chat_completion_with_backoff(messages, model, temperature, **kwargs):
-openai_api_key = kwargs["api_key"] if kwargs.get("api_key") else os.getenv("OPENAI_API_KEY")
+def chat_completion_with_backoff(messages, model_name, temperature, openai_api_key=None):
chat = ChatOpenAI(
-model_name=model,
+model_name=model_name,
temperature=temperature,
-openai_api_key=openai_api_key,
+openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=10,
max_retries=1,
)