diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py
index 8051afdb..226af3fb 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/gpt.py
@@ -1,9 +1,11 @@
# Standard Packages
-import json
import logging
from datetime import datetime
from typing import Optional
+# External Packages
+from langchain.schema import ChatMessage
+
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation import prompts
@@ -17,44 +19,16 @@ from khoj.processor.conversation.utils import (
logger = logging.getLogger(__name__)
-def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500):
+def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
"""
- Answer user query using provided text as reference with OpenAI's GPT
+ Summarize conversation session using the specified OpenAI chat model
"""
- # Setup Prompt from arguments
- prompt = prompts.answer.format(text=text, user_query=user_query)
+ messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
# Get Response from GPT
- logger.debug(f"Prompt for GPT: {prompt}")
+ logger.debug(f"Prompt for GPT: {messages}")
response = completion_with_backoff(
- prompt=prompt,
- model_name=model,
- temperature=temperature,
- max_tokens=max_tokens,
- model_kwargs={"stop": ['"""']},
- openai_api_key=api_key,
- )
-
- # Extract, Clean Message from GPT's Response
- return str(response).replace("\n\n", "")
-
-
-def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
- """
- Summarize user input using OpenAI's GPT
- """
- # Setup Prompt based on Summary Type
- if summary_type == "chat":
- prompt = prompts.summarize_chat.format(text=text)
- elif summary_type == "notes":
- prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
- else:
- raise ValueError(f"Invalid summary type: {summary_type}")
-
- # Get Response from GPT
- logger.debug(f"Prompt for GPT: {prompt}")
- response = completion_with_backoff(
- prompt=prompt,
+ messages=messages,
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -64,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
)
# Extract, Clean Message from GPT's Response
- return str(response).replace("\n\n", "")
+ return str(response.content).replace("\n\n", "")
def extract_questions(
- text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100
+ text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -97,10 +71,11 @@ def extract_questions(
chat_history=chat_history,
text=text,
)
+ messages = [ChatMessage(content=prompt, role="assistant")]
# Get Response from GPT
response = completion_with_backoff(
- prompt=prompt,
+ messages=messages,
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -111,7 +86,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response
try:
questions = (
- response.strip(empty_escape_sequences)
+ response.content.strip(empty_escape_sequences)
.replace("['", '["')
.replace("']", '"]')
.replace("', '", '", "')
@@ -126,31 +101,6 @@ def extract_questions(
return questions
-def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0):
- """
- Extract search type from user query using OpenAI's GPT
- """
- # Setup Prompt to extract search type
- prompt = prompts.search_type + f"{text}\nA:"
- if verbose > 1:
- print(f"Message -> Prompt: {text} -> {prompt}")
-
- # Get Response from GPT
- logger.debug(f"Prompt for GPT: {prompt}")
- response = completion_with_backoff(
- prompt=prompt,
- model_name=model,
- temperature=temperature,
- max_tokens=max_tokens,
- frequency_penalty=0.2,
- model_kwargs={"stop": ["\n"]},
- openai_api_key=api_key,
- )
-
- # Extract, Clean Message from GPT's Response
- return json.loads(response.strip(empty_escape_sequences))
-
-
def converse(
references,
user_query,
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 04b5d108..c04e9042 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -37,12 +37,7 @@ Question: {query}
## Summarize Chat
## --
summarize_chat = PromptTemplate.from_template(
- """
-You are an AI. Summarize the conversation below from your perspective:
-
-{text}
-
-Summarize the conversation from the AI's first-person perspective:"""
+ f"{personality.format()} Summarize the conversation from your first person perspective"
)
@@ -102,7 +97,7 @@ A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: What national parks did I go to last year?
-["National park I visited in {last_new_year} dt>="{last_new_year_date}" dt<"{current_new_year_date}""]
+["National park I visited in {last_new_year} dt>='{last_new_year_date}' dt<'{current_new_year_date}'"]
A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year}.
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 99084bf0..e77b7899 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -9,7 +9,6 @@ import json
# External Packages
from langchain.chat_models import ChatOpenAI
-from langchain.llms import OpenAI
from langchain.schema import ChatMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackManager
@@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
reraise=True,
)
def completion_with_backoff(**kwargs):
- prompt = kwargs.pop("prompt")
- if "api_key" in kwargs:
- kwargs["openai_api_key"] = kwargs.get("api_key")
- else:
+ messages = kwargs.pop("messages")
+ if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
- llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
- return llm(prompt)
+ llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+ return llm(messages=messages)
@retry(
@@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
streaming=True,
verbose=True,
callback_manager=BaseCallbackManager([callback_handler]),
- model_name=model_name,
+ model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20,
max_retries=1,
+ client=None,
)
chat(messages=messages)
@@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
-def message_to_prompt(
- user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
-):
- """Create prompt for GPT from messages and conversation history"""
- gpt_message = f" {gpt_message}" if gpt_message else ""
-
- return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
-
-
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index e87a80d0..7ccdf5cd 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -15,7 +15,7 @@ from sentence_transformers import util
# Internal Packages
from khoj.configure import configure_processor, configure_search
from khoj.processor.conversation.gpt import converse, extract_questions
-from khoj.processor.conversation.utils import message_to_log, message_to_prompt
+from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -448,10 +448,9 @@ async def chat(
user_message_time: str,
compiled_references: List[str],
inferred_queries: List[str],
- chat_session: str,
meta_log,
):
- state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+ state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q,
gpt_response,
@@ -470,7 +469,6 @@ async def chat(
)
# Load Conversation History
- chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return nothing
@@ -479,7 +477,6 @@ async def chat(
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
- model = state.processor_config.conversation.model
chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes"
@@ -489,7 +486,7 @@ async def chat(
if conversation_type == "notes":
# Infer search queries from user message
with timer("Extracting search queries took", logger):
- inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log)
+ inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):
@@ -525,7 +522,6 @@ async def chat(
user_message_time=user_message_time,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
- chat_session=chat_session,
meta_log=meta_log,
)
diff --git a/src/khoj/routers/api_beta.py b/src/khoj/routers/api_beta.py
index 01b7a086..83cca612 100644
--- a/src/khoj/routers/api_beta.py
+++ b/src/khoj/routers/api_beta.py
@@ -1,64 +1,9 @@
# Standard Packages
import logging
-from typing import Optional
# External Packages
from fastapi import APIRouter
-# Internal Packages
-from khoj.routers.api import search
-from khoj.processor.conversation.gpt import (
- answer,
- extract_search_type,
-)
-from khoj.utils.state import SearchType
-from khoj.utils.helpers import get_from_dict
-from khoj.utils import state
-
-
# Initialize Router
api_beta = APIRouter()
logger = logging.getLogger(__name__)
-
-
-# Create Routes
-@api_beta.get("/search")
-def search_beta(q: str, n: Optional[int] = 1):
- # Initialize Variables
- model = state.processor_config.conversation.model
- api_key = state.processor_config.conversation.openai_api_key
-
- # Extract Search Type using GPT
- try:
- metadata = extract_search_type(q, model=model, api_key=api_key, verbose=state.verbose)
- search_type = get_from_dict(metadata, "search-type")
- except Exception as e:
- return {"status": "error", "result": [str(e)], "type": None}
-
- # Search
- search_results = search(q, n=n, t=SearchType(search_type))
-
- # Return response
- return {"status": "ok", "result": search_results, "type": search_type}
-
-
-@api_beta.get("/answer")
-def answer_beta(q: str):
- # Initialize Variables
- model = state.processor_config.conversation.model
- api_key = state.processor_config.conversation.openai_api_key
-
- # Collate context for GPT
- result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
- collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
- logger.debug(f"Reference Context:\n{collated_result}")
-
- # Make GPT respond to user query using provided context
- try:
- gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key)
- status = "ok"
- except Exception as e:
- gpt_response = str(e)
- status = "error"
-
- return {"status": status, "response": gpt_response}
diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py
index be07eefd..b612f16b 100644
--- a/src/khoj/search_filter/date_filter.py
+++ b/src/khoj/search_filter/date_filter.py
@@ -23,7 +23,7 @@ class DateFilter(BaseFilter):
# - dt>="yesterday" dt<"tomorrow"
# - dt>="last week"
# - dt:"2 years ago"
- date_regex = r"dt([:><=]{1,2})\"(.*?)\""
+ date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
def __init__(self, entry_key="raw"):
self.entry_key = entry_key
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 3adc6e9d..155cdcc6 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
self.model = processor_config.model
self.chat_model = processor_config.chat_model
self.conversation_logfile = Path(processor_config.conversation_logfile)
- self.chat_session = ""
+ self.chat_session: List[str] = []
self.meta_log: dict = {}
diff --git a/tests/test_chat_actors.py b/tests/test_chat_actors.py
index 9f8e821d..e645fc13 100644
--- a/tests/test_chat_actors.py
+++ b/tests/test_chat_actors.py
@@ -33,9 +33,9 @@ def test_extract_question_with_date_filter_from_relative_day():
# Assert
expected_responses = [
- ('dt="1984-04-01"', ""),
- ('dt>="1984-04-01"', 'dt<"1984-04-02"'),
- ('dt>"1984-03-31"', 'dt<"1984-04-02"'),
+ ("dt='1984-04-01'", ""),
+ ("dt>='1984-04-01'", "dt<'1984-04-02'"),
+ ("dt>'1984-03-31'", "dt<'1984-04-02'"),
]
assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
@@ -51,7 +51,7 @@ def test_extract_question_with_date_filter_from_relative_month():
response = extract_questions("Which countries did I visit last month?")
# Assert
- expected_responses = [('dt>="1984-03-01"', 'dt<"1984-04-01"'), ('dt>="1984-03-01"', 'dt<="1984-03-31"')]
+ expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")]
assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to March 1984 in response but got: " + response[0]
@@ -67,9 +67,9 @@ def test_extract_question_with_date_filter_from_relative_year():
# Assert
expected_responses = [
- ('dt>="1984-01-01"', ""),
- ('dt>="1984-01-01"', 'dt<"1985-01-01"'),
- ('dt>="1984-01-01"', 'dt<="1984-12-31"'),
+ ("dt>='1984-01-01'", ""),
+ ("dt>='1984-01-01'", "dt<'1985-01-01'"),
+ ("dt>='1984-01-01'", "dt<='1984-12-31'"),
]
assert len(response) == 1
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
@@ -172,8 +172,8 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
# Assert
expected_responses = [
- ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
- ('dt>="2000-04-01"', 'dt<="2000-04-31"'),
+ ("dt>='2000-04-01'", "dt<'2000-05-01'"),
+ ("dt>='2000-04-01'", "dt<='2000-04-30'"),
]
assert len(response) == 1
assert "Masai Mara" in response[0]