Move remaining chat actors to use OpenAI chat models

- Deprecate the unused beta /answer and /search type identification endpoints and associated GPT functions
- Update extract_questions to use GPT4
- Update summarize method to default to GPT-3.5
- Update date filter to support quoting values in single quotes too. So now both dt>'2023-04-01' and dt>"2023-04-01" should work
- Remove "model" field from chat settings on the web interface
This commit is contained in:
Debanjum 2023-07-07 18:53:05 -07:00 committed by GitHub
commit 4b79d8216f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 49 additions and 174 deletions

View file

@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config):
else: else:
# Initialize Conversation Logs # Initialize Conversation Logs
conversation_processor.meta_log = {} conversation_processor.meta_log = {}
conversation_processor.chat_session = "" conversation_processor.chat_session = []
return conversation_processor return conversation_processor
@ -225,9 +225,9 @@ def save_chat_session():
chat_session = state.processor_config.conversation.chat_session chat_session = state.processor_config.conversation.chat_session
openai_api_key = state.processor_config.conversation.openai_api_key openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = state.processor_config.conversation.meta_log conversation_log = state.processor_config.conversation.meta_log
model = state.processor_config.conversation.model chat_model = state.processor_config.conversation.chat_model
session = { session = {
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"]), "session-end": len(conversation_log["chat"]),
} }
@ -242,7 +242,7 @@ def save_chat_session():
with open(conversation_logfile, "w+", encoding="utf-8") as logfile: with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
json.dump(conversation_log, logfile, indent=2) json.dump(conversation_log, logfile, indent=2)
state.processor_config.conversation.chat_session = None state.processor_config.conversation.chat_session = []
logger.info("📩 Saved current chat session to conversation logs") logger.info("📩 Saved current chat session to conversation logs")

View file

@ -16,9 +16,17 @@
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}"> <input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}">
</td> </td>
</tr> </tr>
<tr>
<td>
<label for="chat-model">Chat Model</label>
</td>
<td>
<input type="text" id="chat-model" name="chat-model" value="{{ current_config['chat_model'] }}">
</td>
</tr>
</table> </table>
<table > <table style="display: none;">
<tr style="display: none;"> <tr>
<td> <td>
<label for="conversation-logfile">Conversation Logfile</label> <label for="conversation-logfile">Conversation Logfile</label>
</td> </td>
@ -34,14 +42,6 @@
<input type="text" id="model" name="model" value="{{ current_config['model'] }}"> <input type="text" id="model" name="model" value="{{ current_config['model'] }}">
</td> </td>
</tr> </tr>
<tr>
<td>
<label for="chat-model">Chat Model</label>
</td>
<td>
<input type="text" id="chat-model" name="chat-model" value="{{ current_config['chat_model'] }}">
</td>
</tr>
</table> </table>
<div class="section"> <div class="section">
<div id="success" style="display: none;" ></div> <div id="success" style="display: none;" ></div>

View file

@ -1,9 +1,11 @@
# Standard Packages # Standard Packages
import json
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
# External Packages
from langchain.schema import ChatMessage
# Internal Packages # Internal Packages
from khoj.utils.constants import empty_escape_sequences from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
@ -17,44 +19,16 @@ from khoj.processor.conversation.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500): def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
""" """
Answer user query using provided text as reference with OpenAI's GPT Summarize conversation session using the specified OpenAI chat model
""" """
# Setup Prompt from arguments messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
prompt = prompts.answer.format(text=text, user_query=user_query)
# Get Response from GPT # Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}") logger.debug(f"Prompt for GPT: {messages}")
response = completion_with_backoff( response = completion_with_backoff(
prompt=prompt, messages=messages,
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
model_kwargs={"stop": ['"""']},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
return str(response).replace("\n\n", "")
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
"""
Summarize user input using OpenAI's GPT
"""
# Setup Prompt based on Summary Type
if summary_type == "chat":
prompt = prompts.summarize_chat.format(text=text)
elif summary_type == "notes":
prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
else:
raise ValueError(f"Invalid summary type: {summary_type}")
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = completion_with_backoff(
prompt=prompt,
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -64,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
return str(response).replace("\n\n", "") return str(response.content).replace("\n\n", "")
def extract_questions( def extract_questions(
text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100 text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
@ -97,10 +71,11 @@ def extract_questions(
chat_history=chat_history, chat_history=chat_history,
text=text, text=text,
) )
messages = [ChatMessage(content=prompt, role="assistant")]
# Get Response from GPT # Get Response from GPT
response = completion_with_backoff( response = completion_with_backoff(
prompt=prompt, messages=messages,
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -111,7 +86,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
try: try:
questions = ( questions = (
response.strip(empty_escape_sequences) response.content.strip(empty_escape_sequences)
.replace("['", '["') .replace("['", '["')
.replace("']", '"]') .replace("']", '"]')
.replace("', '", '", "') .replace("', '", '", "')
@ -126,31 +101,6 @@ def extract_questions(
return questions return questions
def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0):
"""
Extract search type from user query using OpenAI's GPT
"""
# Setup Prompt to extract search type
prompt = prompts.search_type + f"{text}\nA:"
if verbose > 1:
print(f"Message -> Prompt: {text} -> {prompt}")
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = completion_with_backoff(
prompt=prompt,
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
model_kwargs={"stop": ["\n"]},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
return json.loads(response.strip(empty_escape_sequences))
def converse( def converse(
references, references,
user_query, user_query,

View file

@ -37,12 +37,7 @@ Question: {query}
## Summarize Chat ## Summarize Chat
## -- ## --
summarize_chat = PromptTemplate.from_template( summarize_chat = PromptTemplate.from_template(
""" f"{personality.format()} Summarize the conversation from your first person perspective"
You are an AI. Summarize the conversation below from your perspective:
{text}
Summarize the conversation from the AI's first-person perspective:"""
) )
@ -102,7 +97,7 @@ A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: What national parks did I go to last year? Q: What national parks did I go to last year?
["National park I visited in {last_new_year} dt>="{last_new_year_date}" dt<"{current_new_year_date}""] ["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}. A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.

View file

@ -9,7 +9,6 @@ import json
# External Packages # External Packages
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
reraise=True, reraise=True,
) )
def completion_with_backoff(**kwargs): def completion_with_backoff(**kwargs):
prompt = kwargs.pop("prompt") messages = kwargs.pop("messages")
if "api_key" in kwargs: if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = kwargs.get("api_key")
else:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
llm = OpenAI(**kwargs, request_timeout=20, max_retries=1) llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
return llm(prompt) return llm(messages=messages)
@retry( @retry(
@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
streaming=True, streaming=True,
verbose=True, verbose=True,
callback_manager=BaseCallbackManager([callback_handler]), callback_manager=BaseCallbackManager([callback_handler]),
model_name=model_name, model_name=model_name, # type: ignore
temperature=temperature, temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20, request_timeout=20,
max_retries=1, max_retries=1,
client=None,
) )
chat(messages=messages) chat(messages=messages)
@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])] return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
def message_to_prompt(
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):
"""Create prompt for GPT from messages and conversation history"""
gpt_message = f" {gpt_message}" if gpt_message else ""
return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]): def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log""" """Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = { default_khoj_message_metadata = {

View file

@ -15,7 +15,7 @@ from sentence_transformers import util
# Internal Packages # Internal Packages
from khoj.configure import configure_processor, configure_search from khoj.configure import configure_processor, configure_search
from khoj.processor.conversation.gpt import converse, extract_questions from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log, message_to_prompt from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
@ -448,10 +448,9 @@ async def chat(
user_message_time: str, user_message_time: str,
compiled_references: List[str], compiled_references: List[str],
inferred_queries: List[str], inferred_queries: List[str],
chat_session: str,
meta_log, meta_log,
): ):
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log( state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, q,
gpt_response, gpt_response,
@ -470,7 +469,6 @@ async def chat(
) )
# Load Conversation History # Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return nothing # If user query is empty, return nothing
@ -479,7 +477,6 @@ async def chat(
# Initialize Variables # Initialize Variables
api_key = state.processor_config.conversation.openai_api_key api_key = state.processor_config.conversation.openai_api_key
model = state.processor_config.conversation.model
chat_model = state.processor_config.conversation.chat_model chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes" conversation_type = "general" if q.startswith("@general") else "notes"
@ -489,7 +486,7 @@ async def chat(
if conversation_type == "notes": if conversation_type == "notes":
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log) inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)
# Collate search results as context for GPT # Collate search results as context for GPT
with timer("Searching knowledge base took", logger): with timer("Searching knowledge base took", logger):
@ -525,7 +522,6 @@ async def chat(
user_message_time=user_message_time, user_message_time=user_message_time,
compiled_references=compiled_references, compiled_references=compiled_references,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
chat_session=chat_session,
meta_log=meta_log, meta_log=meta_log,
) )

View file

@ -1,64 +1,9 @@
# Standard Packages # Standard Packages
import logging import logging
from typing import Optional
# External Packages # External Packages
from fastapi import APIRouter from fastapi import APIRouter
# Internal Packages
from khoj.routers.api import search
from khoj.processor.conversation.gpt import (
answer,
extract_search_type,
)
from khoj.utils.state import SearchType
from khoj.utils.helpers import get_from_dict
from khoj.utils import state
# Initialize Router # Initialize Router
api_beta = APIRouter() api_beta = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create Routes
@api_beta.get("/search")
def search_beta(q: str, n: Optional[int] = 1):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
# Extract Search Type using GPT
try:
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
search_type = get_from_dict(metadata, "search-type")
except Exception as e:
return {"status": "error", "result": [str(e)], "type": None}
# Search
search_results = search(q, n=n, t=SearchType(search_type))
# Return response
return {"status": "ok", "result": search_results, "type": search_type}
@api_beta.get("/answer")
def answer_beta(q: str):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
# Collate context for GPT
result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
logger.debug(f"Reference Context:\n{collated_result}")
# Make GPT respond to user query using provided context
try:
gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
return {"status": status, "response": gpt_response}

View file

@ -23,7 +23,7 @@ class DateFilter(BaseFilter):
# - dt>="yesterday" dt<"tomorrow" # - dt>="yesterday" dt<"tomorrow"
# - dt>="last week" # - dt>="last week"
# - dt:"2 years ago" # - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\"" date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
def __init__(self, entry_key="raw"): def __init__(self, entry_key="raw"):
self.entry_key = entry_key self.entry_key = entry_key

View file

@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
self.model = processor_config.model self.model = processor_config.model
self.chat_model = processor_config.chat_model self.chat_model = processor_config.chat_model
self.conversation_logfile = Path(processor_config.conversation_logfile) self.conversation_logfile = Path(processor_config.conversation_logfile)
self.chat_session = "" self.chat_session: List[str] = []
self.meta_log: dict = {} self.meta_log: dict = {}

View file

@ -33,9 +33,9 @@ def test_extract_question_with_date_filter_from_relative_day():
# Assert # Assert
expected_responses = [ expected_responses = [
('dt="1984-04-01"', ""), ("dt='1984-04-01'", ""),
('dt>="1984-04-01"', 'dt<"1984-04-02"'), ("dt>='1984-04-01'", "dt<'1984-04-02'"),
('dt>"1984-03-31"', 'dt<"1984-04-02"'), ("dt>'1984-03-31'", "dt<'1984-04-02'"),
] ]
assert len(response) == 1 assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
@ -51,7 +51,7 @@ def test_extract_question_with_date_filter_from_relative_month():
response = extract_questions("Which countries did I visit last month?") response = extract_questions("Which countries did I visit last month?")
# Assert # Assert
expected_responses = [('dt>="1984-03-01"', 'dt<"1984-04-01"'), ('dt>="1984-03-01"', 'dt<="1984-03-31"')] expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")]
assert len(response) == 1 assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to March 1984 in response but got: " + response[0] "Expected date filter to limit to March 1984 in response but got: " + response[0]
@ -67,9 +67,9 @@ def test_extract_question_with_date_filter_from_relative_year():
# Assert # Assert
expected_responses = [ expected_responses = [
('dt>="1984-01-01"', ""), ("dt>='1984-01-01'", ""),
('dt>="1984-01-01"', 'dt<"1985-01-01"'), ("dt>='1984-01-01'", "dt<'1985-01-01'"),
('dt>="1984-01-01"', 'dt<="1984-12-31"'), ("dt>='1984-01-01'", "dt<='1984-12-31'"),
] ]
assert len(response) == 1 assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
@ -172,8 +172,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
# Assert # Assert
expected_responses = [ expected_responses = [
('dt>="2000-04-01"', 'dt<"2000-05-01"'), ("dt>='2000-04-01'", "dt<'2000-05-01'"),
('dt>="2000-04-01"', 'dt<="2000-04-31"'), ("dt>='2000-04-01'", "dt<='2000-04-30'"),
] ]
assert len(response) == 1 assert len(response) == 1
assert "Masai Mara" in response[0] assert "Masai Mara" in response[0]