mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Move remaining chat actors to use OpenAI chat models
- Deprecate the unused beta /answer and /search type identification endpoints and associated GPT functions
- Update extract_questions to use GPT-4
- Update summarize method to default to GPT-3.5
- Update date filter to support quoting values in single quotes too. So now both dt>'2023-04-01' and dt>"2023-04-01" should work
- Remove "model" field from chat settings on the web interface
This commit is contained in:
commit
4b79d8216f
10 changed files with 49 additions and 174 deletions
|
@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config):
|
|||
else:
|
||||
# Initialize Conversation Logs
|
||||
conversation_processor.meta_log = {}
|
||||
conversation_processor.chat_session = ""
|
||||
conversation_processor.chat_session = []
|
||||
|
||||
return conversation_processor
|
||||
|
||||
|
@ -225,9 +225,9 @@ def save_chat_session():
|
|||
chat_session = state.processor_config.conversation.chat_session
|
||||
openai_api_key = state.processor_config.conversation.openai_api_key
|
||||
conversation_log = state.processor_config.conversation.meta_log
|
||||
model = state.processor_config.conversation.model
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
session = {
|
||||
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
|
||||
"summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
|
||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||
"session-end": len(conversation_log["chat"]),
|
||||
}
|
||||
|
@ -242,7 +242,7 @@ def save_chat_session():
|
|||
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
||||
json.dump(conversation_log, logfile, indent=2)
|
||||
|
||||
state.processor_config.conversation.chat_session = None
|
||||
state.processor_config.conversation.chat_session = []
|
||||
logger.info("📩 Saved current chat session to conversation logs")
|
||||
|
||||
|
||||
|
|
|
@ -16,9 +16,17 @@
|
|||
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="chat-model">Chat Model</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="chat-model" name="chat-model" value="{{ current_config['chat_model'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<table >
|
||||
<tr style="display: none;">
|
||||
<table style="display: none;">
|
||||
<tr>
|
||||
<td>
|
||||
<label for="conversation-logfile">Conversation Logfile</label>
|
||||
</td>
|
||||
|
@ -34,14 +42,6 @@
|
|||
<input type="text" id="model" name="model" value="{{ current_config['model'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="chat-model">Chat Model</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="chat-model" name="chat-model" value="{{ current_config['chat_model'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;" ></div>
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
# Standard Packages
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# External Packages
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.processor.conversation import prompts
|
||||
|
@ -17,44 +19,16 @@ from khoj.processor.conversation.utils import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500):
|
||||
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
||||
"""
|
||||
Answer user query using provided text as reference with OpenAI's GPT
|
||||
Summarize conversation session using the specified OpenAI chat model
|
||||
"""
|
||||
# Setup Prompt from arguments
|
||||
prompt = prompts.answer.format(text=text, user_query=user_query)
|
||||
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Prompt for GPT: {prompt}")
|
||||
logger.debug(f"Prompt for GPT: {messages}")
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
model_kwargs={"stop": ['"""']},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return str(response).replace("\n\n", "")
|
||||
|
||||
|
||||
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
|
||||
"""
|
||||
Summarize user input using OpenAI's GPT
|
||||
"""
|
||||
# Setup Prompt based on Summary Type
|
||||
if summary_type == "chat":
|
||||
prompt = prompts.summarize_chat.format(text=text)
|
||||
elif summary_type == "notes":
|
||||
prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
|
||||
else:
|
||||
raise ValueError(f"Invalid summary type: {summary_type}")
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Prompt for GPT: {prompt}")
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
@ -64,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
|
|||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return str(response).replace("\n\n", "")
|
||||
return str(response.content).replace("\n\n", "")
|
||||
|
||||
|
||||
def extract_questions(
|
||||
text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
||||
text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
|
@ -97,10 +71,11 @@ def extract_questions(
|
|||
chat_history=chat_history,
|
||||
text=text,
|
||||
)
|
||||
messages = [ChatMessage(content=prompt, role="assistant")]
|
||||
|
||||
# Get Response from GPT
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
|
@ -111,7 +86,7 @@ def extract_questions(
|
|||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
questions = (
|
||||
response.strip(empty_escape_sequences)
|
||||
response.content.strip(empty_escape_sequences)
|
||||
.replace("['", '["')
|
||||
.replace("']", '"]')
|
||||
.replace("', '", '", "')
|
||||
|
@ -126,31 +101,6 @@ def extract_questions(
|
|||
return questions
|
||||
|
||||
|
||||
def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0):
|
||||
"""
|
||||
Extract search type from user query using OpenAI's GPT
|
||||
"""
|
||||
# Setup Prompt to extract search type
|
||||
prompt = prompts.search_type + f"{text}\nA:"
|
||||
if verbose > 1:
|
||||
print(f"Message -> Prompt: {text} -> {prompt}")
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Prompt for GPT: {prompt}")
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
frequency_penalty=0.2,
|
||||
model_kwargs={"stop": ["\n"]},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return json.loads(response.strip(empty_escape_sequences))
|
||||
|
||||
|
||||
def converse(
|
||||
references,
|
||||
user_query,
|
||||
|
|
|
@ -37,12 +37,7 @@ Question: {query}
|
|||
## Summarize Chat
|
||||
## --
|
||||
summarize_chat = PromptTemplate.from_template(
|
||||
"""
|
||||
You are an AI. Summarize the conversation below from your perspective:
|
||||
|
||||
{text}
|
||||
|
||||
Summarize the conversation from the AI's first-person perspective:"""
|
||||
f"{personality.format()} Summarize the conversation from your first person perspective"
|
||||
)
|
||||
|
||||
|
||||
|
@ -102,7 +97,7 @@ A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
|
|||
|
||||
Q: What national parks did I go to last year?
|
||||
|
||||
["National park I visited in {last_new_year} dt>="{last_new_year_date}" dt<"{current_new_year_date}""]
|
||||
["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]
|
||||
|
||||
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import json
|
|||
|
||||
# External Packages
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.schema import ChatMessage
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
|
@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(**kwargs):
|
||||
prompt = kwargs.pop("prompt")
|
||||
if "api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = kwargs.get("api_key")
|
||||
else:
|
||||
messages = kwargs.pop("messages")
|
||||
if not "openai_api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(prompt)
|
||||
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(messages=messages)
|
||||
|
||||
|
||||
@retry(
|
||||
|
@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
|||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name,
|
||||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
)
|
||||
|
||||
chat(messages=messages)
|
||||
|
@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
|
|||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||
|
||||
|
||||
def message_to_prompt(
|
||||
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
|
||||
):
|
||||
"""Create prompt for GPT from messages and conversation history"""
|
||||
gpt_message = f" {gpt_message}" if gpt_message else ""
|
||||
|
||||
return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
|
||||
|
||||
|
||||
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
|
|
|
@ -15,7 +15,7 @@ from sentence_transformers import util
|
|||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
|
||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@ -448,10 +448,9 @@ async def chat(
|
|||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
chat_session: str,
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
|
@ -470,7 +469,6 @@ async def chat(
|
|||
)
|
||||
|
||||
# Load Conversation History
|
||||
chat_session = state.processor_config.conversation.chat_session
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# If user query is empty, return nothing
|
||||
|
@ -479,7 +477,6 @@ async def chat(
|
|||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
model = state.processor_config.conversation.model
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
|
@ -489,7 +486,7 @@ async def chat(
|
|||
if conversation_type == "notes":
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log)
|
||||
inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
with timer("Searching knowledge base took", logger):
|
||||
|
@ -525,7 +522,6 @@ async def chat(
|
|||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
chat_session=chat_session,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,64 +1,9 @@
|
|||
# Standard Packages
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Internal Packages
|
||||
from khoj.routers.api import search
|
||||
from khoj.processor.conversation.gpt import (
|
||||
answer,
|
||||
extract_search_type,
|
||||
)
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.helpers import get_from_dict
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
# Initialize Router
|
||||
api_beta = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Create Routes
|
||||
@api_beta.get("/search")
|
||||
def search_beta(q: str, n: Optional[int] = 1):
|
||||
# Initialize Variables
|
||||
model = state.processor_config.conversation.model
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
|
||||
# Extract Search Type using GPT
|
||||
try:
|
||||
metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
|
||||
search_type = get_from_dict(metadata, "search-type")
|
||||
except Exception as e:
|
||||
return {"status": "error", "result": [str(e)], "type": None}
|
||||
|
||||
# Search
|
||||
search_results = search(q, n=n, t=SearchType(search_type))
|
||||
|
||||
# Return response
|
||||
return {"status": "ok", "result": search_results, "type": search_type}
|
||||
|
||||
|
||||
@api_beta.get("/answer")
|
||||
def answer_beta(q: str):
|
||||
# Initialize Variables
|
||||
model = state.processor_config.conversation.model
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
|
||||
# Collate context for GPT
|
||||
result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
|
||||
collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
|
||||
logger.debug(f"Reference Context:\n{collated_result}")
|
||||
|
||||
# Make GPT respond to user query using provided context
|
||||
try:
|
||||
gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key)
|
||||
status = "ok"
|
||||
except Exception as e:
|
||||
gpt_response = str(e)
|
||||
status = "error"
|
||||
|
||||
return {"status": status, "response": gpt_response}
|
||||
|
|
|
@ -23,7 +23,7 @@ class DateFilter(BaseFilter):
|
|||
# - dt>="yesterday" dt<"tomorrow"
|
||||
# - dt>="last week"
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
|
||||
|
||||
def __init__(self, entry_key="raw"):
|
||||
self.entry_key = entry_key
|
||||
|
|
|
@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
|
|||
self.model = processor_config.model
|
||||
self.chat_model = processor_config.chat_model
|
||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||
self.chat_session = ""
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
||||
|
||||
|
|
|
@ -33,9 +33,9 @@ def test_extract_question_with_date_filter_from_relative_day():
|
|||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
('dt="1984-04-01"', ""),
|
||||
('dt>="1984-04-01"', 'dt<"1984-04-02"'),
|
||||
('dt>"1984-03-31"', 'dt<"1984-04-02"'),
|
||||
("dt='1984-04-01'", ""),
|
||||
("dt>='1984-04-01'", "dt<'1984-04-02'"),
|
||||
("dt>'1984-03-31'", "dt<'1984-04-02'"),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
|
@ -51,7 +51,7 @@ def test_extract_question_with_date_filter_from_relative_month():
|
|||
response = extract_questions("Which countries did I visit last month?")
|
||||
|
||||
# Assert
|
||||
expected_responses = [('dt>="1984-03-01"', 'dt<"1984-04-01"'), ('dt>="1984-03-01"', 'dt<="1984-03-31"')]
|
||||
expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")]
|
||||
assert len(response) == 1
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to March 1984 in response but got: " + response[0]
|
||||
|
@ -67,9 +67,9 @@ def test_extract_question_with_date_filter_from_relative_year():
|
|||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
('dt>="1984-01-01"', ""),
|
||||
('dt>="1984-01-01"', 'dt<"1985-01-01"'),
|
||||
('dt>="1984-01-01"', 'dt<="1984-12-31"'),
|
||||
("dt>='1984-01-01'", ""),
|
||||
("dt>='1984-01-01'", "dt<'1985-01-01'"),
|
||||
("dt>='1984-01-01'", "dt<='1984-12-31'"),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
|
@ -172,8 +172,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
|
|||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||
('dt>="2000-04-01"', 'dt<="2000-04-31"'),
|
||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert "Masai Mara" in response[0]
|
||||
|
|
Loading…
Reference in a new issue