def extract_questions(text, model="text-davinci-003", conversation_log=None, api_key=None, temperature=0, max_tokens=100):
    """
    Infer search queries to retrieve relevant notes to answer user query

    Builds a few-shot prompt from the current date, the recent chat history and
    the user's message, asks the OpenAI completion model for a JSON list of
    search queries and parses the model's reply.

    Args:
        text: The latest message from the user.
        model: Name of the OpenAI completion model to query.
        conversation_log: Conversation log. The entries under its "chat" key that are
            by "khoj" are read for their "intent" and "message" fields.
            None is treated as an empty log.
        api_key: OpenAI API key. Falls back to the OPENAI_API_KEY environment variable.
        temperature: Sampling temperature passed to the completion API.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        List of search query strings. The list can be empty when the model decides no search is required.
        Falls back to [text] when the model's reply is not a JSON list of strings.
    """
    # Initialize Variables
    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
    conversation_log = conversation_log or {}

    # Extract Past User Message and Inferred Questions from Conversation Log
    # Only the last 4 log entries are looked at and of those only the ones by khoj,
    # as these carry the user's query (under "intent") along with khoj's answer (under "message").
    # Answers are prefixed with "A: " to match the layout of the examples in the prompt below.
    chat_history = "".join(
        f'Q: {chat["intent"]["query"]}\n\n'
        f'{chat["intent"].get("inferred-queries") or [chat["intent"]["query"]]}\n\n'
        f'A: {chat["message"]}\n\n'
        for chat in conversation_log.get("chat", [])[-4:]
        if chat["by"] == "khoj"
    )

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    prompt = f"""
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the users notes.
- The user will provide their questions and answers to you for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.

What searches, if any, will you need to perform to answer the users question?
Provide search queries as a JSON list of strings
Current Date: {today.strftime("%A, %Y-%m-%d")}

Q: How was my trip to Cambodia?

["How was my trip to Cambodia?"]

A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.

Q: Who did i visit that temple with?

["Who did I visit the Angkor Wat Temple in Cambodia with?"]

A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.

Q: What national parks did I go to last year?

["National park I visited in {last_new_year.strftime("%Y")} dt>=\\"{last_new_year.strftime("%Y-%m-%d")}\\" dt<\\"{current_new_year.strftime("%Y-%m-%d")}\\""]

A: You visited the Grand Canyon and Yellowstone National Park in {last_new_year.strftime("%Y")}.

Q: How are you feeling today?

[]

A: I'm feeling a little bored. Helping you will hopefully make me feel better!

Q: How many tennis balls fit in the back of a 2002 Honda Civic?

["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]

A: 1085 tennis balls will fit in the trunk of a Honda Civic

Q: Is Bob older than Tom?

["When was Bob born?", "What is Tom's age?"]

A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.

Q: What is their age difference?

["What is Bob's age?", "What is Tom's age?"]

A: Bob is {current_new_year.year - 1984 - 30} years older than Tom. As Bob is {current_new_year.year - 1984} years old and Tom is 30 years old.

{chat_history}
Q: {text}

"""

    # Get Response from GPT
    # Stop at the first newline or answer label so only the line with the search queries is generated
    response = openai.Completion.create(
        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop=["A: ", "\n"]
    )

    # Extract, Clean Message from GPT's Response
    response_text = response["choices"][0]["text"]
    try:
        questions = json.loads(
            # Clean response to increase likelihood of valid JSON. E.g replace ' with " to enclose strings
            response_text.strip(empty_escape_sequences)
            .replace("['", '["')
            .replace("']", '"]')
            .replace("', '", '", "')
        )
    except json.decoder.JSONDecodeError:
        questions = None

    # Valid JSON that is not a list of strings (e.g a bare string or an object) is as unusable
    # to the caller as invalid JSON. So both cases fall back to the user's message
    if not isinstance(questions, list) or not all(isinstance(question, str) for question in questions):
        logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response_text}")
        questions = [text]
    logger.debug(f"Extracted Questions by GPT: {questions}")
    return questions
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day():
    "Search query for a question about yesterday should be limited to the previous day by a date filter"
    # Arrange
    # Each pair is an acceptable (lower bound, upper bound) date filter limiting the search to 1984-04-01.
    # The empty upper bound is contained in any query, so the first pair only requires the exact date filter.
    acceptable_date_filters = (
        ('dt="1984-04-01"', ""),
        ('dt>="1984-04-01"', 'dt<"1984-04-02"'),
        ('dt>"1984-03-31"', 'dt<"1984-04-02"'),
    )

    # Act
    response = extract_questions("Where did I go for dinner yesterday?")

    # Assert
    assert len(response) == 1
    query = response[0]
    has_date_filter = any(lower in query and upper in query for lower, upper in acceptable_date_filters)
    assert has_date_filter, "Expected date filter to limit to 1st April 1984 in response but got: " + query
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month():
    "Search query for a question about last month should be limited to the previous calendar month"
    # Act
    response = extract_questions("Which countries did I visit last month?")

    # Assert
    # Each tuple is an acceptable (start, end) date filter pair covering March 1984
    expected_responses = [('dt>="1984-03-01"', 'dt<"1984-04-01"'), ('dt>="1984-03-01"', 'dt<="1984-03-31"')]
    assert len(response) == 1
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to March 1984 in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year():
    "Search query for a question about this year should be limited to the current calendar year"
    # Act
    response = extract_questions("Which countries have I visited this year?")

    # Assert
    # Each tuple is an acceptable (start, end) date filter pair covering 1984.
    # The empty end filter is contained in any query, so the first pair only requires the start filter
    expected_responses = [
        ('dt>="1984-01-01"', ""),
        ('dt>="1984-01-01"', 'dt<"1985-01-01"'),
        ('dt>="1984-01-01"', 'dt<="1984-12-31"'),
    ]
    assert len(response) == 1
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to 1984 in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message():
    "Message with two explicit questions should result in one search query per question"
    # Act
    response = extract_questions("What is the Sun? What is the Moon?")

    # Assert
    expected_responses = [
        ("sun", "moon"),
    ]
    assert len(response) == 2
    # Show all queries on failure, as either of the two may be the unexpected one
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        f"Expected two search queries in response but got: {response}"
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message():
    "Comparison question should be broken into one search query per entity being compared"
    # Act
    response = extract_questions("Is Morpheus taller than Neo?")

    # Assert
    expected_responses = [
        ("morpheus", "neo"),
    ]
    assert len(response) == 2
    # Show all queries on failure, as either of the two may be the unexpected one
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        f"Expected two search queries in response but got: {response}"
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history():
    "Search query should pick up context from the previous question in chat history"
    # Arrange
    message_list = [
        ("What is the name of Mr. Vaders daughter?", "Princess Leia", ""),
    ]

    # Act
    response = extract_questions("Does he have any sons?", conversation_log=populate_chat_history(message_list))

    # Assert
    assert len(response) == 1
    assert "Vader" in response[0]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history():
    "Search query should pick up context from the previous answer in chat history"
    # Arrange
    message_list = [
        ("What is the name of Mr. Vaders daughter?", "Princess Leia", ""),
    ]

    # Act
    response = extract_questions("Is she a Jedi?", conversation_log=populate_chat_history(message_list))

    # Assert
    assert len(response) == 1
    assert "Leia" in response[0]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_and_answer_from_chat_history():
    "Search query should pick up context from both the previous question and answer in chat history"
    # Arrange
    message_list = [
        ("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", ""),
    ]

    # Act
    response = extract_questions("Who is their father?", conversation_log=populate_chat_history(message_list))

    # Assert
    assert len(response) == 1
    assert "Leia" in response[0] and "Luke" in response[0]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history():
    "Search query should pick up both the place and the date range mentioned in chat history"
    # Arrange
    message_list = [
        ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", ""),
    ]

    # Act
    response = extract_questions(
        "What was the Pizza place we ate at over there?", conversation_log=populate_chat_history(message_list)
    )

    # Assert
    # Each tuple is an acceptable (start, end) date filter pair covering April 2000.
    # April has 30 days, so the inclusive end filter is the 30th
    expected_responses = [
        ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
        ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
    ]
    assert len(response) == 1
    assert "Masai Mara" in response[0]
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to April 2000 in response but got: " + response[0]
    )
for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got" + response + "Expected assistants name, [K|k]hoj, in response but got: " + response ) @@ -42,20 +201,16 @@ def test_chat_with_no_chat_history_or_retrieved_content(): @pytest.mark.chatquality def test_answer_from_chat_history_and_no_content(): # Arrange - conversation_log = {"chat": []} message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""), ("When was I born?", "You were born on 1st April 1984.", ""), ] - # Generate conversation logs - for user_message, gpt_message, _ in message_list: - conversation_log["chat"] += message_to_log(user_message, gpt_message) # Act response = converse( text="", # Assume no context retrieved from notes for the user_query user_query="What is my name?", - conversation_log=conversation_log, + conversation_log=populate_chat_history(message_list), api_key=api_key, ) @@ -63,7 +218,7 @@ def test_answer_from_chat_history_and_no_content(): expected_responses = ["Testatron", "testatron"] assert len(response) > 0 assert any([expected_response in response for expected_response in expected_responses]), ( - "Expected [T|t]estatron in response but got" + response + "Expected [T|t]estatron in response but got: " + response ) @@ -72,20 +227,16 @@ def test_answer_from_chat_history_and_no_content(): def test_answer_from_chat_history_and_previously_retrieved_content(): "Chat actor needs to use context in previous notes and chat history to answer question" # Arrange - conversation_log = {"chat": []} message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", ""), ("When was I born?", "You were born on 1st April 1984.", "Testatron was born on 1st April 1984 in Testville."), ] - # Generate conversation logs - for user_message, gpt_message, context in message_list: - conversation_log["chat"] += message_to_log(user_message, gpt_message, {"context": context}) # Act response = converse( text="", # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=conversation_log, + conversation_log=populate_chat_history(message_list), api_key=api_key, ) @@ -100,20 +251,16 @@ def test_answer_from_chat_history_and_previously_retrieved_content(): def test_answer_from_chat_history_and_currently_retrieved_content(): "Chat actor needs to use context across currently retrieved notes and chat history to answer question" # Arrange - conversation_log = {"chat": []} message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""), ("When was I born?", "You were born on 1st April 1984.", ""), ] - # Generate conversation logs - for user_message, gpt_message, context in message_list: - conversation_log["chat"] += message_to_log(user_message, gpt_message, {"context": context}) # Act response = converse( text="Testatron was born on 1st April 1984 in Testville.", # Assume context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=conversation_log, + conversation_log=populate_chat_history(message_list), api_key=api_key, ) @@ -127,20 +274,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): def test_no_answer_in_chat_history_or_retrieved_content(): "Chat actor should say don't know as not enough contexts in chat history or retrieved to answer question" # Arrange - conversation_log = {"chat": []} message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
    """
    Generate a conversation log from a list of (user message, gpt message, context) tuples

    Each tuple is converted to log entries by message_to_log. The user message is
    recorded as the intent's query and, wrapped as a single item list in text form,
    as its inferred queries.
    Returns the log as a dictionary with the entries under its "chat" key.
    """
    chat_messages = []
    for user_message, gpt_message, context in message_list:
        khoj_message_metadata = {
            "context": context,
            "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
        }
        chat_messages.extend(message_to_log(user_message, gpt_message, khoj_message_metadata))
    return {"chat": chat_messages}