Merge pull request #977 from khoj-ai/features/improve-tool-selection

- JSON extraction from LLMs is pretty reliable now, so get the input tools and output modes all in one go. This helps the model think through the full cycle of what it wants to do to handle the request holistically (a sketch of the combined selection contract follows the commit details below).
- Make slight improvements to the tool selection indicators
sabaimran 2024-11-17 20:08:19 -08:00 committed by GitHub
commit 7d50c6590d
10 changed files with 104 additions and 242 deletions
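For orientation, here is a minimal sketch of the consolidated contract this commit moves to: the selector prompt now returns both the input data sources and the output mode in a single JSON object, so the chat API needs only one model call and one parsing pass. The helper name `parse_tool_selection` and the hard-coded option sets are illustrative only, not code from this PR.

```python
import json
from typing import Dict, List

# Illustrative stand-ins; the real code derives these from ConversationCommand
# plus the agent's configured tools and output modes.
ALLOWED_SOURCES = {"notes", "online", "webpage", "code", "general", "default"}
ALLOWED_OUTPUTS = {"text", "image", "diagram"}


def parse_tool_selection(raw_response: str) -> Dict[str, List[str]]:
    """Parse one selector response of the form {"source": [...], "output": [...]}."""
    response = json.loads(raw_response)

    # Keep only recognized, non-empty entries; fall back to safe defaults otherwise.
    sources = [s.strip() for s in response.get("source", []) if s.strip() in ALLOWED_SOURCES]
    outputs = [o.strip() for o in response.get("output", []) if o.strip() in ALLOWED_OUTPUTS]
    return {"source": sources or ["default"], "output": outputs or ["text"]}


# A single selector call now yields both halves of the decision.
print(parse_tool_selection('{"source": ["online"], "output": ["text"]}'))
# -> {'source': ['online'], 'output': ['text']}
```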

View file

@@ -14,7 +14,7 @@ Try it out yourself! https://app.khoj.dev
 ## Self-Hosting
-Online search works out of the box even when self-hosting. Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. No API key setup is necessary.
+Online search can work even with self-hosting! Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. You can get a free API key via https://jina.ai/reader. Set the `JINA_API_KEY` environment variable to your Jina AI reader API key to enable online search.
 To improve online search, set the `SERPER_DEV_API_KEY` environment variable to your [Serper.dev](https://serper.dev/) API key. These search results include additional context like answer box, knowledge graph etc.

View file

@@ -30,6 +30,7 @@ import {
     Code,
     Shapes,
     Trash,
+    Toolbox,
 } from "@phosphor-icons/react";
 import DOMPurify from "dompurify";
@@ -282,8 +283,8 @@ function chooseIconFromHeader(header: string, iconColor: string) {
         return <Cloud className={`${classNames}`} />;
     }
-    if (compareHeader.includes("data sources")) {
-        return <Folder className={`${classNames}`} />;
+    if (compareHeader.includes("tools")) {
+        return <Toolbox className={`${classNames}`} />;
     }
     if (compareHeader.includes("notes")) {

View file

@@ -565,59 +565,6 @@ Here's some additional context about you:
 """
 )
-pick_relevant_output_mode = PromptTemplate.from_template(
-    """
-You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
-{personality_context}
-You have access to a limited set of modes for your response.
-You can only use one of these modes.
-{modes}
-Here are some examples:
-Example:
-Chat History:
-User: I just visited Jerusalem for the first time. Pull up my notes from the trip.
-AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.
-Q: Draw a picture of my trip to Jerusalem.
-Khoj: {{"output": "image"}}
-Example:
-Chat History:
-User: I'm having trouble deciding which laptop to get. I want something with at least 16 GB of RAM and a 1 TB SSD.
-AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.
-Q: What are the specs of the new Dell XPS 15?
-Khoj: {{"output": "text"}}
-Example:
-Chat History:
-User: Where did I go on my last vacation?
-AI: You went to Jordan and visited Petra, the Dead Sea, and Wadi Rum.
-Q: Remind me who did I go with on that trip?
-Khoj: {{"output": "text"}}
-Example:
-Chat History:
-User: How's the weather outside? Current Location: Bali, Indonesia
-AI: It's currently 28°C and partly cloudy in Bali.
-Q: Share a painting using the weather for Bali every morning.
-Khoj: {{"output": "automation"}}
-Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
-Chat History:
-{chat_history}
-Q: {query}
-Khoj:
-""".strip()
-)
 plan_function_execution = PromptTemplate.from_template(
     """
 You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
@@ -679,18 +626,23 @@ previous_iteration = PromptTemplate.from_template(
 """
 )
-pick_relevant_information_collection_tools = PromptTemplate.from_template(
+pick_relevant_tools = PromptTemplate.from_template(
     """
 You are Khoj, an extremely smart and helpful search assistant.
 {personality_context}
 - You have access to a variety of data sources to help you answer the user's question
 - You can use the data sources listed below to collect more relevant information
-- You can use any combination of these data sources to answer the user's question
+- You can select certain types of output to respond to the user's question. Select just one output type to answer the user's question
+- You can use any combination of these data sources and output types to answer the user's question
-Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources:
+Which of the tools listed below you would use to answer the user's question? You **only** have access to the following:
+Inputs:
 {tools}
+Outputs:
+{outputs}
 Here are some examples:
 Example:
@@ -699,7 +651,7 @@ User: I'm thinking of moving to a new city. I'm trying to decide between New Yor
 AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
 Q: What is the population of each of those cities?
-Khoj: {{"source": ["online"]}}
+Khoj: {{"source": ["online"], "output": ["text"]}}
 Example:
 Chat History:
@@ -707,14 +659,7 @@ User: I'm thinking of my next vacation idea. Ideally, I want to see something ne
 AI: Excellent! Taking a vacation is a great way to relax and recharge.
 Q: Where did Grandma grow up?
-Khoj: {{"source": ["notes"]}}
+Khoj: {{"source": ["notes"], "output": ["text"]}}
-Example:
-Chat History:
-Q: What can you do for me?
-Khoj: {{"source": ["notes", "online"]}}
 Example:
 Chat History:
@@ -722,7 +667,7 @@ User: Good morning
 AI: Good morning! How can I help you today?
 Q: How can I share my files with Khoj?
-Khoj: {{"source": ["default", "online"]}}
+Khoj: {{"source": ["default", "online"], "output": ["text"]}}
 Example:
 Chat History:
@@ -730,15 +675,15 @@ User: What is the first element in the periodic table?
 AI: The first element in the periodic table is Hydrogen.
 Q: Summarize this article https://en.wikipedia.org/wiki/Hydrogen
-Khoj: {{"source": ["webpage"]}}
+Khoj: {{"source": ["webpage"], "output": ["text"]}}
 Example:
 Chat History:
 User: I want to start a new hobby. I'm thinking of learning to play the guitar.
 AI: Learning to play the guitar is a great hobby. It can be a lot of fun and a great way to express yourself.
-Q: What is the first element of the periodic table?
-Khoj: {{"source": ["general"]}}
+Q: Draw a painting of a guitar.
+Khoj: {{"source": ["general"], "output": ["image"]}}
 Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data sources as a list of strings in a JSON object. Do not say anything else.
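To make the template's moving parts concrete, here is a rough sketch of how the `{tools}` and `{outputs}` placeholders get filled before the prompt is sent to the model. It uses plain `str.format` and made-up descriptions; the real code builds these option strings from `tool_descriptions_for_llm` and `mode_descriptions_for_llm` and formats the PromptTemplate above.

```python
# Illustrative descriptions only; the real values live in
# tool_descriptions_for_llm and mode_descriptions_for_llm.
tool_descriptions = {
    "notes": "To search the user's personal knowledge base.",
    "online": "To search for the latest information from the internet.",
}
mode_descriptions = {
    "text": "Use this if a normal text response is sufficient.",
    "image": "Use this to create a new picture from a description.",
}

# Mirror the option-string construction shown in the helpers diff below:
# one '- "name": "description"' line per tool/output available to the agent.
tools_str = "".join(f'- "{name}": "{desc}"\n' for name, desc in tool_descriptions.items())
outputs_str = "".join(f'- "{name}": "{desc}"\n' for name, desc in mode_descriptions.items())

prompt_skeleton = "Inputs:\n{tools}\nOutputs:\n{outputs}\nQ: {query}\nKhoj:"
print(prompt_skeleton.format(tools=tools_str, outputs=outputs_str, query="What's new in tech?"))
```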

View file

@@ -46,8 +46,7 @@ from khoj.routers.helpers import (
     FeedbackData,
     acreate_title_from_history,
     agenerate_chat_response,
-    aget_relevant_information_sources,
-    aget_relevant_output_modes,
+    aget_relevant_tools_to_execute,
     construct_automation_created_message,
     create_automation,
     gather_raw_query_files,
@@ -753,7 +752,7 @@ async def chat(
         attached_file_context = gather_raw_query_files(query_files)
         if conversation_commands == [ConversationCommand.Default] or is_automated_task:
-            conversation_commands = await aget_relevant_information_sources(
+            conversation_commands = await aget_relevant_tools_to_execute(
                 q,
                 meta_log,
                 is_automated_task,
@@ -769,19 +768,9 @@ async def chat(
                 conversation_commands = [ConversationCommand.Research]
             conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
-            async for result in send_event(
-                ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
-            ):
+            async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
                 yield result
-            mode = await aget_relevant_output_modes(
-                q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
-            )
-            async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
-                yield result
-            if mode not in conversation_commands:
-                conversation_commands.append(mode)
         for cmd in conversation_commands:
             try:
                 await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
@@ -1175,8 +1164,27 @@ async def chat(
                 inferred_queries.append(better_diagram_description_prompt)
                 diagram_description = excalidraw_diagram_description
             else:
-                async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
+                error_message = "Failed to generate diagram. Please try again later."
+                async for result in send_llm_response(error_message):
                     yield result
+                await sync_to_async(save_to_conversation_log)(
+                    q,
+                    error_message,
+                    user,
+                    meta_log,
+                    user_message_time,
+                    inferred_queries=[better_diagram_description_prompt],
+                    client_application=request.user.client_app,
+                    conversation_id=conversation_id,
+                    compiled_references=compiled_references,
+                    online_results=online_results,
+                    code_results=code_results,
+                    query_images=uploaded_images,
+                    train_of_thought=train_of_thought,
+                    raw_query_files=raw_query_files,
+                    tracer=tracer,
+                )
                 return
             content_obj = {

View file

@@ -336,7 +336,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
     return is_safe, reason
-async def aget_relevant_information_sources(
+async def aget_relevant_tools_to_execute(
     query: str,
     conversation_history: dict,
     is_task: bool,
@@ -360,75 +360,6 @@ async def aget_relevant_information_sources(
         if len(agent_tools) == 0 or tool.value in agent_tools:
             tool_options_str += f'- "{tool.value}": "{description}"\n'
-    chat_history = construct_chat_history(conversation_history)
-    if query_images:
-        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
-    personality_context = (
-        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
-    )
-    relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
-        query=query,
-        tools=tool_options_str,
-        chat_history=chat_history,
-        personality_context=personality_context,
-    )
-    with timer("Chat actor: Infer information sources to refer", logger):
-        response = await send_message_to_model_wrapper(
-            relevant_tools_prompt,
-            response_type="json_object",
-            user=user,
-            query_files=query_files,
-            tracer=tracer,
-        )
-    try:
-        response = clean_json(response)
-        response = json.loads(response)
-        response = [q.strip() for q in response["source"] if q.strip()]
-        if not isinstance(response, list) or not response or len(response) == 0:
-            logger.error(f"Invalid response for determining relevant tools: {response}")
-            return tool_options
-        final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
-        for llm_suggested_tool in response:
-            # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
-            if llm_suggested_tool in tool_options.keys() and (
-                len(agent_tools) == 0 or llm_suggested_tool in agent_tools
-            ):
-                # Check whether the tool exists as a valid ConversationCommand
-                final_response.append(ConversationCommand(llm_suggested_tool))
-        if is_none_or_empty(final_response):
-            if len(agent_tools) == 0:
-                final_response = [ConversationCommand.Default]
-            else:
-                final_response = [ConversationCommand.General]
-    except Exception:
-        logger.error(f"Invalid response for determining relevant tools: {response}")
-        if len(agent_tools) == 0:
-            final_response = [ConversationCommand.Default]
-        else:
-            final_response = agent_tools
-    return final_response
-async def aget_relevant_output_modes(
-    query: str,
-    conversation_history: dict,
-    is_task: bool = False,
-    user: KhojUser = None,
-    query_images: List[str] = None,
-    agent: Agent = None,
-    tracer: dict = {},
-):
-    """
-    Given a query, determine which of the available tools the agent should use in order to answer appropriately.
-    """
     mode_options = dict()
     mode_options_str = ""
@@ -451,37 +382,65 @@ async def aget_relevant_output_modes(
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
     )
-    relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
+    relevant_tools_prompt = prompts.pick_relevant_tools.format(
         query=query,
-        modes=mode_options_str,
+        tools=tool_options_str,
+        outputs=mode_options_str,
         chat_history=chat_history,
         personality_context=personality_context,
     )
-    with timer("Chat actor: Infer output mode for chat response", logger):
+    with timer("Chat actor: Infer information sources to refer", logger):
         response = await send_message_to_model_wrapper(
-            relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
+            relevant_tools_prompt,
+            response_type="json_object",
+            user=user,
+            query_files=query_files,
+            tracer=tracer,
         )
     try:
         response = clean_json(response)
         response = json.loads(response)
-        if is_none_or_empty(response):
-            return ConversationCommand.Text
-        output_mode = response["output"]
-        # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
-        if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
-            # Check whether the tool exists as a valid ConversationCommand
-            return ConversationCommand(output_mode)
-        logger.error(f"Invalid output mode selected: {output_mode}. Defaulting to text.")
-        return ConversationCommand.Text
+        input_tools = [q.strip() for q in response["source"] if q.strip()]
+        if not isinstance(input_tools, list) or not input_tools or len(input_tools) == 0:
+            logger.error(f"Invalid response for determining relevant tools: {input_tools}")
+            return tool_options
+        output_modes = [q.strip() for q in response["output"] if q.strip()]
+        if not isinstance(output_modes, list) or not output_modes or len(output_modes) == 0:
+            logger.error(f"Invalid response for determining relevant output modes: {output_modes}")
+            return mode_options
+        final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
+        for llm_suggested_tool in input_tools:
+            # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
+            if llm_suggested_tool in tool_options.keys() and (
+                len(agent_tools) == 0 or llm_suggested_tool in agent_tools
+            ):
+                # Check whether the tool exists as a valid ConversationCommand
+                final_response.append(ConversationCommand(llm_suggested_tool))
+        for llm_suggested_output in output_modes:
+            # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
+            if llm_suggested_output in mode_options.keys() and (
+                len(output_modes) == 0 or llm_suggested_output in output_modes
+            ):
+                # Check whether the tool exists as a valid ConversationCommand
+                final_response.append(ConversationCommand(llm_suggested_output))
+        if is_none_or_empty(final_response):
+            if len(agent_tools) == 0:
+                final_response = [ConversationCommand.Default, ConversationCommand.Text]
+            else:
+                final_response = [ConversationCommand.General, ConversationCommand.Text]
     except Exception:
-        logger.error(f"Invalid response for determining output mode: {response}")
-        return ConversationCommand.Text
+        logger.error(f"Invalid response for determining relevant tools: {response}")
+        if len(agent_tools) == 0:
+            final_response = [ConversationCommand.Default, ConversationCommand.Text]
+        else:
+            final_response = agent_tools
+    return final_response
 async def infer_webpage_urls(
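Callers now get the data sources and the output mode back from this single coroutine instead of two separate LLM round trips. Here is a rough usage sketch, modeled on the calls in the tests below; it assumes Khoj's Django/server context is already initialized and leaves the remaining keyword arguments (user, agent, etc.) at their defaults:

```python
import asyncio

from khoj.routers.helpers import aget_relevant_tools_to_execute


async def main():
    # One call returns a mixed list of ConversationCommand values:
    # input sources (e.g. "online", "notes") plus an output mode (e.g. "text").
    commands = await aget_relevant_tools_to_execute(
        "What's the weather in Patagonia this week?",
        {},  # empty chat history
        is_task=False,
    )
    print([cmd.value for cmd in commands])


asyncio.run(main())
```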

View file

@@ -365,7 +365,7 @@ tool_descriptions_for_llm = {
     ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
     ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
     ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
-    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
+    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create plaintext documents, and create charts with quantitative data. Matplotlib, bs4, pandas, numpy, etc. are available.",
     ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
 }
@@ -373,14 +373,13 @@ function_calling_description_for_llm = {
     ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
     ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.",
     ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.",
-    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
+    ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create plaintext documents, and create charts with quantitative data. Matplotlib, bs4, pandas, numpy, etc. are available.",
 }
 mode_descriptions_for_llm = {
-    ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This does not support generating charts or graphs.",
-    ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
-    ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
-    ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text.",
+    ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This DOES NOT support generating charts or graphs. It is for creative images.",
+    ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query or you don't feel strongly about the other modes.",
+    ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text. This does not work for charts, graphs, or quantitative data. It is for mind mapping, flowcharts, etc.",
 }
 mode_descriptions_for_agent = {

View file

@@ -18,7 +18,6 @@ from khoj.processor.conversation.offline.chat_model import (
 )
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_output_modes
 from khoj.utils.constants import default_offline_chat_models
@@ -549,34 +548,6 @@ def test_filter_questions():
     assert filtered_questions[0] == "Who is on the basketball team?"
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-@pytest.mark.django_db(transaction=True)
-async def test_use_text_response_mode(client_offline_chat):
-    # Arrange
-    user_query = "What's the latest in the Israel/Palestine conflict?"
-    # Act
-    mode = await aget_relevant_output_modes(user_query, {})
-    # Assert
-    assert mode.value == "text"
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-@pytest.mark.django_db(transaction=True)
-async def test_use_image_response_mode(client_offline_chat):
-    # Arrange
-    user_query = "Paint a picture of the scenery in Timbuktu in the winter"
-    # Act
-    mode = await aget_relevant_output_modes(user_query, {})
-    # Assert
-    assert mode.value == "image"
 # Helpers
 # ----------------------------------------------------------------------------------------------------
 def populate_chat_history(message_list):

View file

@@ -7,7 +7,7 @@ from freezegun import freeze_time
 from khoj.database.models import Agent, Entry, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_information_sources
+from khoj.routers.helpers import aget_relevant_tools_to_execute
 from tests.helpers import ConversationFactory
 SKIP_TESTS = True
@@ -735,7 +735,7 @@ async def test_get_correct_tools_online(client_offline_chat):
     user_query = "What's the weather in Patagonia this week?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -750,7 +750,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
     user_query = "Where did I go for my first battleship training?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -765,7 +765,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
     user_query = "What's the highest point in Patagonia and have I been there?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -782,7 +782,7 @@ async def test_get_correct_tools_general(client_offline_chat):
     user_query = "How many noble gases are there?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -806,7 +806,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
     chat_history = create_conversation(chat_log, default_user2)
     # Act
-    tools = await aget_relevant_information_sources(user_query, chat_history, is_task=False)
+    tools = await aget_relevant_tools_to_execute(user_query, chat_history, is_task=False)
     # Assert
     tools = [tool.value for tool in tools]
View file

@@ -8,8 +8,7 @@ from freezegun import freeze_time
 from khoj.processor.conversation.openai.gpt import converse, extract_questions
 from khoj.processor.conversation.utils import message_to_log
 from khoj.routers.helpers import (
-    aget_relevant_information_sources,
-    aget_relevant_output_modes,
+    aget_relevant_tools_to_execute,
     generate_online_subqueries,
     infer_webpage_urls,
     schedule_query,
@@ -524,26 +523,6 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
     ), "Expected search query to include site:khoj.dev but got: " + str(responses)
-# ----------------------------------------------------------------------------------------------------
-@pytest.mark.anyio
-@pytest.mark.django_db(transaction=True)
-@pytest.mark.parametrize(
-    "user_query, expected_mode",
-    [
-        ("What's the latest in the Israel/Palestine conflict?", "text"),
-        ("Summarize the latest tech news every Monday evening", "automation"),
-        ("Paint a scenery in Timbuktu in the winter", "image"),
-        ("Remind me, when did I last visit the Serengeti?", "text"),
-    ],
-)
-async def test_use_default_response_mode(chat_client, user_query, expected_mode):
-    # Act
-    mode = await aget_relevant_output_modes(user_query, {})
-    # Assert
-    assert mode.value == expected_mode
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.anyio
 @pytest.mark.django_db(transaction=True)
@@ -559,7 +538,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
     chat_client, user_query, expected_conversation_commands
 ):
     # Act
-    conversation_commands = await aget_relevant_information_sources(user_query, {}, False, False)
+    conversation_commands = await aget_relevant_tools_to_execute(user_query, {}, False, False)
     # Assert
     assert set(expected_conversation_commands) == set(conversation_commands)

View file

@@ -8,7 +8,7 @@ from freezegun import freeze_time
 from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
-from khoj.routers.helpers import aget_relevant_information_sources
+from khoj.routers.helpers import aget_relevant_tools_to_execute
 from tests.helpers import ConversationFactory
 # Initialize variables for tests
@@ -719,7 +719,7 @@ async def test_get_correct_tools_online(chat_client):
     user_query = "What's the weather in Patagonia this week?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, False, False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -734,7 +734,7 @@ async def test_get_correct_tools_notes(chat_client):
     user_query = "Where did I go for my first battleship training?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, False, False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -749,7 +749,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
     user_query = "What's the highest point in Patagonia and have I been there?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, False, False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -766,7 +766,7 @@ async def test_get_correct_tools_general(chat_client):
     user_query = "How many noble gases are there?"
     # Act
-    tools = await aget_relevant_information_sources(user_query, {}, False, False)
+    tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
     # Assert
     tools = [tool.value for tool in tools]
@@ -790,7 +790,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
     chat_history = generate_history(chat_log)
     # Act
-    tools = await aget_relevant_information_sources(user_query, chat_history, False, False)
+    tools = await aget_relevant_tools_to_execute(user_query, chat_history, False, False)
     # Assert
     tools = [tool.value for tool in tools]