Have Khoj dynamically select conversation command(s) in chat (#641)

* Have Khoj dynamically select which conversation command(s) are to be used in the chat flow
- Intercept the commands if in default mode, and have Khoj dynamically guess which tools would be the most relevant for answering the user's query
* Remove conditional for default to enter online search mode
* Add multiple-tool examples in the prompt, make prompt for tools more specific to info collection
This commit is contained in:
sabaimran 2024-02-11 03:41:32 -08:00 committed by GitHub
parent 69344a6aa6
commit a3eb17b7d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 373 additions and 63 deletions

View file

@@ -479,7 +479,7 @@ class ConversationAdapters:
conversation_id: int = None, conversation_id: int = None,
user_message: str = None, user_message: str = None,
): ):
slug = user_message.strip()[:200] if not is_none_or_empty(user_message) else None slug = user_message.strip()[:200] if user_message else None
if conversation_id: if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id) conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
else: else:

View file

@@ -130,7 +130,7 @@ def converse_offline(
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,
completion_func=None, completion_func=None,
conversation_command=ConversationCommand.Default, conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]: ) -> Union[ThreadedGenerator, Iterator[str]]:
@@ -148,27 +148,24 @@ def converse_offline(
# Initialize Variables # Initialize Variables
compiled_references_message = "\n\n".join({f"{item}" for item in references}) compiled_references_message = "\n\n".join({f"{item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()]) return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format()) completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()]) return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy() simplified_online_results = online_results.copy()
for result in online_results: for result in online_results:
if online_results[result].get("extracted_content"): if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"] simplified_online_results[result] = online_results[result]["extracted_content"]
conversation_primer = prompts.online_search_conversation.format( conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
query=user_query, online_results=str(simplified_online_results) if ConversationCommand.Notes in conversation_commands:
) conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
conversation_primer = prompts.notes_conversation_gpt4all.format(
query=user_query, references=compiled_references_message
)
# Setup Prompt with Primer or Conversation History # Setup Prompt with Primer or Conversation History
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")

View file

@@ -122,7 +122,7 @@ def converse(
api_key: Optional[str] = None, api_key: Optional[str] = None,
temperature: float = 0.2, temperature: float = 0.2,
completion_func=None, completion_func=None,
conversation_command=ConversationCommand.Default, conversation_commands=[ConversationCommand.Default],
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
): ):
@@ -133,26 +133,25 @@ def converse(
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references = "\n\n".join({f"# {item}" for item in references}) compiled_references = "\n\n".join({f"# {item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query)
# Get Conversation Primer appropriate to Conversation Type # Get Conversation Primer appropriate to Conversation Type
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references): if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format()) completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()]) return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results): elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format()) completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()]) return iter([prompts.no_online_results_found.format()])
elif conversation_command == ConversationCommand.Online:
if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy() simplified_online_results = online_results.copy()
for result in online_results: for result in online_results:
if online_results[result].get("extracted_content"): if online_results[result].get("extracted_content"):
simplified_online_results[result] = online_results[result]["extracted_content"] simplified_online_results[result] = online_results[result]["extracted_content"]
conversation_primer = prompts.online_search_conversation.format( conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
query=user_query, online_results=str(simplified_online_results) if ConversationCommand.Notes in conversation_commands:
) conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n{conversation_primer}"
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
# Setup Prompt with Primer or Conversation History # Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(

View file

@@ -112,7 +112,6 @@ notes_conversation_gpt4all = PromptTemplate.from_template(
""" """
User's Notes: User's Notes:
{references} {references}
Question: {query}
""".strip() """.strip()
) )
@@ -139,7 +138,13 @@ Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations. Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
Information from the internet: {online_results} Information from the internet: {online_results}
""".strip()
)
## Query prompt
## --
query_prompt = PromptTemplate.from_template(
"""
Query: {query}""".strip() Query: {query}""".strip()
) )
@@ -285,6 +290,60 @@ Collate the relevant information from the website to answer the target query.
""".strip() """.strip()
) )
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
{tools}
Here are some example responses:
Example 1:
Chat History:
User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco.
AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
Q: What is the population of each of those cities?
Khoj: ["online"]
Example 2:
Chat History:
User: I've been having a hard time at work. I'm thinking of quitting.
AI: I'm sorry to hear that. It's important to take care of your mental health. Have you considered talking to your manager about your concerns?
Q: What are the best ways to quit a job?
Khoj: ["general"]
Example 3:
Chat History:
User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting.
AI: Excellent! Taking a vacation is a great way to relax and recharge.
Q: Where did Grandma grow up?
Khoj: ["notes"]
Example 4:
Chat History:
Q: I want to make chocolate cake. What was my recipe?
Khoj: ["notes"]
Example 5:
Chat History:
Q: What's the latest news with the first company I worked for?
Khoj: ["notes", "online"]
Now it's your turn to pick the tools you would like to use to answer the user's question. Provide your response as a list of strings.
Chat History:
{chat_history}
Q: {query}
A:
""".strip()
)
online_search_conversation_subqueries = PromptTemplate.from_template( online_search_conversation_subqueries = PromptTemplate.from_template(
""" """
You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question. You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question.

View file

@@ -274,7 +274,7 @@ async def extract_references_and_questions(
q: str, q: str,
n: int, n: int,
d: float, d: float,
conversation_type: ConversationCommand = ConversationCommand.Default, conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@@ -282,7 +282,7 @@ async def extract_references_and_questions(
compiled_references: List[Any] = [] compiled_references: List[Any] = []
inferred_queries: List[str] = [] inferred_queries: List[str] = []
if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online: if not ConversationCommand.Notes in conversation_commands:
return compiled_references, inferred_queries, q return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user): if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):

View file

@@ -21,6 +21,7 @@ from khoj.routers.helpers import (
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
agenerate_chat_response, agenerate_chat_response,
aget_relevant_information_sources,
get_conversation_command, get_conversation_command,
is_ready_to_chat, is_ready_to_chat,
text_to_image, text_to_image,
@@ -207,7 +208,7 @@ async def set_conversation_title(
) )
@api_chat.get("", response_class=Response) @api_chat.get("/", response_class=Response)
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(
request: Request, request: Request,
@@ -229,25 +230,9 @@ async def chat(
q = unquote(q) q = unquote(q)
await is_ready_to_chat(user) await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True) conversation_commands = [get_conversation_command(query=q, any_references=True)]
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command) if conversation_commands == [ConversationCommand.Help]:
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
elif conversation_command == ConversationCommand.Help:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user) conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None: if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config() conversation_config = await ConversationAdapters.aget_default_conversation_config()
@@ -255,7 +240,23 @@ async def chat(
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user): meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log
if conversation_commands == [ConversationCommand.Default]:
conversation_commands = await aget_relevant_information_sources(q, meta_log)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
)
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format() no_entries_found_format = no_entries_found.format()
if stream: if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200) return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
@@ -263,7 +264,10 @@ async def chat(
response_obj = {"response": no_entries_found_format} response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200) return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online: if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
try: try:
online_results = await search_with_google(defiltered_query, meta_log) online_results = await search_with_google(defiltered_query, meta_log)
except ValueError as e: except ValueError as e:
@@ -272,12 +276,12 @@ async def chat(
media_type="text/event-stream", media_type="text/event-stream",
status_code=200, status_code=200,
) )
elif conversation_command == ConversationCommand.Image: elif conversation_commands == [ConversationCommand.Image]:
update_telemetry_state( update_telemetry_state(
request=request, request=request,
telemetry_type="api", telemetry_type="api",
api="chat", api="chat",
metadata={"conversation_command": conversation_command.value}, metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__, **common.__dict__,
) )
image, status_code, improved_image_prompt = await text_to_image(q, meta_log) image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
@@ -308,13 +312,13 @@ async def chat(
compiled_references, compiled_references,
online_results, online_results,
inferred_queries, inferred_queries,
conversation_command, conversation_commands,
user, user,
request.user.client_app, request.user.client_app,
conversation_id, conversation_id,
) )
chat_metadata.update({"conversation_command": conversation_command.value}) chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})
update_telemetry_state( update_telemetry_state(
request=request, request=request,

View file

@@ -34,7 +34,11 @@ from khoj.processor.conversation.utils import (
) )
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import (
ConversationCommand,
log_telemetry,
tool_descriptions_for_llm,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -105,6 +109,15 @@ def update_telemetry_state(
] ]
def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
return chat_history
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand: def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
if query.startswith("/notes"): if query.startswith("/notes"):
return ConversationCommand.Notes return ConversationCommand.Notes
@@ -128,15 +141,50 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args) return await loop.run_in_executor(executor, generate_chat_response, *args)
async def aget_relevant_information_sources(query: str, conversation_history: dict):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
"""
tool_options = dict()
for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
chat_history = construct_chat_history(conversation_history)
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=str(tool_options),
chat_history=chat_history,
)
response = await send_message_to_model_wrapper(relevant_tools_prompt)
try:
response = response.strip()
response = json.loads(response)
response = [q.strip() for q in response if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
logger.error(f"Invalid response for determining relevant tools: {response}")
return tool_options
final_response = []
for llm_suggested_tool in response:
if llm_suggested_tool in tool_options.keys():
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))
return final_response
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}")
return [ConversationCommand.Default]
async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]: async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
chat_history = "" chat_history = construct_chat_history(conversation_history)
for chat in conversation_history.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"Khoj: {chat['message']}\n"
utc_date = datetime.utcnow().strftime("%Y-%m-%d") utc_date = datetime.utcnow().strftime("%Y-%m-%d")
online_queries_prompt = prompts.online_search_conversation_subqueries.format( online_queries_prompt = prompts.online_search_conversation_subqueries.format(
@@ -241,14 +289,14 @@ def generate_chat_response(
compiled_references: List[str] = [], compiled_references: List[str] = [],
online_results: Dict[str, Any] = {}, online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default, conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None, user: KhojUser = None,
client_application: ClientApplication = None, client_application: ClientApplication = None,
conversation_id: int = None, conversation_id: int = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}") logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {} metadata = {}
@@ -278,7 +326,7 @@ def generate_chat_response(
loaded_model=loaded_model, loaded_model=loaded_model,
conversation_log=meta_log, conversation_log=meta_log,
completion_func=partial_completion, completion_func=partial_completion,
conversation_command=conversation_command, conversation_commands=conversation_commands,
model=conversation_config.chat_model, model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer, tokenizer_name=conversation_config.tokenizer,
@@ -296,7 +344,7 @@ def generate_chat_response(
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
conversation_command=conversation_command, conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer, tokenizer_name=conversation_config.tokenizer,
) )

View file

@@ -282,6 +282,13 @@ command_descriptions = {
ConversationCommand.Help: "Display a help message with all available commands and other metadata.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
} }
tool_descriptions_for_llm = {
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question",
ConversationCommand.General: "Use this when you can answer the question without needing any additional online or personal information",
ConversationCommand.Notes: "Use this when you would like to use the user's personal knowledge base to answer the question",
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
}
def generate_random_name(): def generate_random_name():
# List of adjectives and nouns to choose from # List of adjectives and nouns to choose from

View file

@@ -8,6 +8,7 @@ from freezegun import freeze_time
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
from tests.helpers import ConversationFactory from tests.helpers import ConversationFactory
SKIP_TESTS = True SKIP_TESTS = True
@@ -440,3 +441,100 @@ def test_answer_requires_multiple_independent_searches(client_offline_chat):
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message "Expected Xi is older than Namita, but got: " + response_message
) )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_online(client_offline_chat):
# Arrange
user_query = "What's the weather in Patagonia this week?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["online"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_notes(client_offline_chat):
# Arrange
user_query = "Where did I go for my first battleship training?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["notes"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat):
# Arrange
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert len(tools) == 2
assert "online" in tools or "general" in tools
assert "notes" in tools
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_general(client_offline_chat):
# Arrange
user_query = "How many noble gases are there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["general"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
(
"Let's talk about the current events around the world.",
"Sure, let's discuss the current events. What would you like to know?",
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st."),
]
chat_history = populate_chat_history(chat_log)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)
# Assert
tools = [tool.value for tool in tools]
assert tools == ["online"]
def populate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message in message_list:
conversation_log["chat"] += message_to_log(
user_message,
gpt_message,
{"context": [], "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
return conversation_log

View file

@@ -8,6 +8,7 @@ from freezegun import freeze_time
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_information_sources
from tests.helpers import ConversationFactory from tests.helpers import ConversationFactory
# Initialize variables for tests # Initialize variables for tests
@@ -416,3 +417,100 @@ def test_answer_using_file_filter(chat_client):
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message "Expected Xi is older than Namita, but got: " + response_message
) )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_online(chat_client):
# Arrange
user_query = "What's the weather in Patagonia this week?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["online"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_notes(chat_client):
# Arrange
user_query = "Where did I go for my first battleship training?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["notes"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_online_or_general_and_notes(chat_client):
# Arrange
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert len(tools) == 2
assert "online" in tools or "general" in tools
assert "notes" in tools
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_general(chat_client):
# Arrange
user_query = "How many noble gases are there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
# Assert
tools = [tool.value for tool in tools]
assert tools == ["general"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(chat_client):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
(
"Let's talk about the current events around the world.",
"Sure, let's discuss the current events. What would you like to know?",
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st."),
]
chat_history = populate_chat_history(chat_log)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)
# Assert
tools = [tool.value for tool in tools]
assert tools == ["online"]
def populate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message in message_list:
conversation_log["chat"] += message_to_log(
user_message,
gpt_message,
{"context": [], "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
return conversation_log