mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Merge pull request #977 from khoj-ai/features/improve-tool-selection
- JSON extract from LLMs is pretty decent now, so get the input tools and output modes all in one go. It'll help the model think through the full cycle of what it wants to do to handle the request holistically. - Make slight improvements to tool selection indicators
This commit is contained in:
commit
7d50c6590d
10 changed files with 104 additions and 242 deletions
|
@ -14,7 +14,7 @@ Try it out yourself! https://app.khoj.dev
|
|||
|
||||
## Self-Hosting
|
||||
|
||||
Online search works out of the box even when self-hosting. Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. No API key setup is necessary.
|
||||
Online search can work even with self-hosting! Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. You can get a free API key via https://jina.ai/reader. Set the `JINA_API_KEY` environment variable to your Jina AI reader API key to enable online search.
|
||||
|
||||
To improve online search, set the `SERPER_DEV_API_KEY` environment variable to your [Serper.dev](https://serper.dev/) API key. These search results include additional context like answer box, knowledge graph etc.
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import {
|
|||
Code,
|
||||
Shapes,
|
||||
Trash,
|
||||
Toolbox,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
import DOMPurify from "dompurify";
|
||||
|
@ -282,8 +283,8 @@ function chooseIconFromHeader(header: string, iconColor: string) {
|
|||
return <Cloud className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("data sources")) {
|
||||
return <Folder className={`${classNames}`} />;
|
||||
if (compareHeader.includes("tools")) {
|
||||
return <Toolbox className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("notes")) {
|
||||
|
|
|
@ -565,59 +565,6 @@ Here's some additional context about you:
|
|||
"""
|
||||
)
|
||||
|
||||
pick_relevant_output_mode = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
|
||||
{personality_context}
|
||||
You have access to a limited set of modes for your response.
|
||||
You can only use one of these modes.
|
||||
|
||||
{modes}
|
||||
|
||||
Here are some examples:
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: I just visited Jerusalem for the first time. Pull up my notes from the trip.
|
||||
AI: You mention visiting Masjid Al-Aqsa and the Western Wall. You also mention trying the local cuisine and visiting the Dead Sea.
|
||||
|
||||
Q: Draw a picture of my trip to Jerusalem.
|
||||
Khoj: {{"output": "image"}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: I'm having trouble deciding which laptop to get. I want something with at least 16 GB of RAM and a 1 TB SSD.
|
||||
AI: I can help with that. I see online that there is a new model of the Dell XPS 15 that meets your requirements.
|
||||
|
||||
Q: What are the specs of the new Dell XPS 15?
|
||||
Khoj: {{"output": "text"}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: Where did I go on my last vacation?
|
||||
AI: You went to Jordan and visited Petra, the Dead Sea, and Wadi Rum.
|
||||
|
||||
Q: Remind me who did I go with on that trip?
|
||||
Khoj: {{"output": "text"}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: How's the weather outside? Current Location: Bali, Indonesia
|
||||
AI: It's currently 28°C and partly cloudy in Bali.
|
||||
|
||||
Q: Share a painting using the weather for Bali every morning.
|
||||
Khoj: {{"output": "automation"}}
|
||||
|
||||
Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
|
||||
Q: {query}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
plan_function_execution = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
|
||||
|
@ -679,18 +626,23 @@ previous_iteration = PromptTemplate.from_template(
|
|||
"""
|
||||
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
pick_relevant_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant.
|
||||
{personality_context}
|
||||
- You have access to a variety of data sources to help you answer the user's question
|
||||
- You can use the data sources listed below to collect more relevant information
|
||||
- You can use any combination of these data sources to answer the user's question
|
||||
- You can select certain types of output to respond to the user's question. Select just one output type to answer the user's question
|
||||
- You can use any combination of these data sources and output types to answer the user's question
|
||||
|
||||
Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources:
|
||||
Which of the tools listed below you would use to answer the user's question? You **only** have access to the following:
|
||||
|
||||
Inputs:
|
||||
{tools}
|
||||
|
||||
Outputs:
|
||||
{outputs}
|
||||
|
||||
Here are some examples:
|
||||
|
||||
Example:
|
||||
|
@ -699,7 +651,7 @@ User: I'm thinking of moving to a new city. I'm trying to decide between New Yor
|
|||
AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
|
||||
|
||||
Q: What is the population of each of those cities?
|
||||
Khoj: {{"source": ["online"]}}
|
||||
Khoj: {{"source": ["online"], "output": ["text"]}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
|
@ -707,14 +659,7 @@ User: I'm thinking of my next vacation idea. Ideally, I want to see something ne
|
|||
AI: Excellent! Taking a vacation is a great way to relax and recharge.
|
||||
|
||||
Q: Where did Grandma grow up?
|
||||
Khoj: {{"source": ["notes"]}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
|
||||
|
||||
Q: What can you do for me?
|
||||
Khoj: {{"source": ["notes", "online"]}}
|
||||
Khoj: {{"source": ["notes"], "output": ["text"]}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
|
@ -722,7 +667,7 @@ User: Good morning
|
|||
AI: Good morning! How can I help you today?
|
||||
|
||||
Q: How can I share my files with Khoj?
|
||||
Khoj: {{"source": ["default", "online"]}}
|
||||
Khoj: {{"source": ["default", "online"], "output": ["text"]}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
|
@ -730,15 +675,15 @@ User: What is the first element in the periodic table?
|
|||
AI: The first element in the periodic table is Hydrogen.
|
||||
|
||||
Q: Summarize this article https://en.wikipedia.org/wiki/Hydrogen
|
||||
Khoj: {{"source": ["webpage"]}}
|
||||
Khoj: {{"source": ["webpage"], "output": ["text"]}}
|
||||
|
||||
Example:
|
||||
Chat History:
|
||||
User: I want to start a new hobby. I'm thinking of learning to play the guitar.
|
||||
AI: Learning to play the guitar is a great hobby. It can be a lot of fun and a great way to express yourself.
|
||||
|
||||
Q: What is the first element of the periodic table?
|
||||
Khoj: {{"source": ["general"]}}
|
||||
Q: Draw a painting of a guitar.
|
||||
Khoj: {{"source": ["general"], "output": ["image"]}}
|
||||
|
||||
Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data sources as a list of strings in a JSON object. Do not say anything else.
|
||||
|
||||
|
|
|
@ -46,8 +46,7 @@ from khoj.routers.helpers import (
|
|||
FeedbackData,
|
||||
acreate_title_from_history,
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
aget_relevant_tools_to_execute,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
gather_raw_query_files,
|
||||
|
@ -753,7 +752,7 @@ async def chat(
|
|||
attached_file_context = gather_raw_query_files(query_files)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(
|
||||
conversation_commands = await aget_relevant_tools_to_execute(
|
||||
q,
|
||||
meta_log,
|
||||
is_automated_task,
|
||||
|
@ -769,19 +768,9 @@ async def chat(
|
|||
conversation_commands = [ConversationCommand.Research]
|
||||
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
||||
):
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(
|
||||
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||
)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
try:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
|
@ -1175,8 +1164,27 @@ async def chat(
|
|||
inferred_queries.append(better_diagram_description_prompt)
|
||||
diagram_description = excalidraw_diagram_description
|
||||
else:
|
||||
async for result in send_llm_response(f"Failed to generate diagram. Please try again later."):
|
||||
error_message = "Failed to generate diagram. Please try again later."
|
||||
async for result in send_llm_response(error_message):
|
||||
yield result
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
error_message,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
inferred_queries=[better_diagram_description_prompt],
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
query_images=uploaded_images,
|
||||
train_of_thought=train_of_thought,
|
||||
raw_query_files=raw_query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
return
|
||||
|
||||
content_obj = {
|
||||
|
|
|
@ -336,7 +336,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
|||
return is_safe, reason
|
||||
|
||||
|
||||
async def aget_relevant_information_sources(
|
||||
async def aget_relevant_tools_to_execute(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
is_task: bool,
|
||||
|
@ -360,75 +360,6 @@ async def aget_relevant_information_sources(
|
|||
if len(agent_tools) == 0 or tool.value in agent_tools:
|
||||
tool_options_str += f'- "{tool.value}": "{description}"\n'
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
|
||||
query=query,
|
||||
tools=tool_options_str,
|
||||
chat_history=chat_history,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["source"] if q.strip()]
|
||||
if not isinstance(response, list) or not response or len(response) == 0:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
return tool_options
|
||||
|
||||
final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
|
||||
for llm_suggested_tool in response:
|
||||
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
|
||||
if llm_suggested_tool in tool_options.keys() and (
|
||||
len(agent_tools) == 0 or llm_suggested_tool in agent_tools
|
||||
):
|
||||
# Check whether the tool exists as a valid ConversationCommand
|
||||
final_response.append(ConversationCommand(llm_suggested_tool))
|
||||
|
||||
if is_none_or_empty(final_response):
|
||||
if len(agent_tools) == 0:
|
||||
final_response = [ConversationCommand.Default]
|
||||
else:
|
||||
final_response = [ConversationCommand.General]
|
||||
except Exception:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
if len(agent_tools) == 0:
|
||||
final_response = [ConversationCommand.Default]
|
||||
else:
|
||||
final_response = agent_tools
|
||||
return final_response
|
||||
|
||||
|
||||
async def aget_relevant_output_modes(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
is_task: bool = False,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
"""
|
||||
|
||||
mode_options = dict()
|
||||
mode_options_str = ""
|
||||
|
||||
|
@ -451,37 +382,65 @@ async def aget_relevant_output_modes(
|
|||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
|
||||
relevant_tools_prompt = prompts.pick_relevant_tools.format(
|
||||
query=query,
|
||||
modes=mode_options_str,
|
||||
tools=tool_options_str,
|
||||
outputs=mode_options_str,
|
||||
chat_history=chat_history,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||
with timer("Chat actor: Infer information sources to refer", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
try:
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
input_tools = [q.strip() for q in response["source"] if q.strip()]
|
||||
if not isinstance(input_tools, list) or not input_tools or len(input_tools) == 0:
|
||||
logger.error(f"Invalid response for determining relevant tools: {input_tools}")
|
||||
return tool_options
|
||||
|
||||
if is_none_or_empty(response):
|
||||
return ConversationCommand.Text
|
||||
|
||||
output_mode = response["output"]
|
||||
output_modes = [q.strip() for q in response["output"] if q.strip()]
|
||||
if not isinstance(output_modes, list) or not output_modes or len(output_modes) == 0:
|
||||
logger.error(f"Invalid response for determining relevant output modes: {output_modes}")
|
||||
return mode_options
|
||||
|
||||
final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
|
||||
for llm_suggested_tool in input_tools:
|
||||
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
|
||||
if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
|
||||
if llm_suggested_tool in tool_options.keys() and (
|
||||
len(agent_tools) == 0 or llm_suggested_tool in agent_tools
|
||||
):
|
||||
# Check whether the tool exists as a valid ConversationCommand
|
||||
return ConversationCommand(output_mode)
|
||||
final_response.append(ConversationCommand(llm_suggested_tool))
|
||||
|
||||
logger.error(f"Invalid output mode selected: {output_mode}. Defaulting to text.")
|
||||
return ConversationCommand.Text
|
||||
for llm_suggested_output in output_modes:
|
||||
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
|
||||
if llm_suggested_output in mode_options.keys() and (
|
||||
len(output_modes) == 0 or llm_suggested_output in output_modes
|
||||
):
|
||||
# Check whether the tool exists as a valid ConversationCommand
|
||||
final_response.append(ConversationCommand(llm_suggested_output))
|
||||
|
||||
if is_none_or_empty(final_response):
|
||||
if len(agent_tools) == 0:
|
||||
final_response = [ConversationCommand.Default, ConversationCommand.Text]
|
||||
else:
|
||||
final_response = [ConversationCommand.General, ConversationCommand.Text]
|
||||
except Exception:
|
||||
logger.error(f"Invalid response for determining output mode: {response}")
|
||||
return ConversationCommand.Text
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
if len(agent_tools) == 0:
|
||||
final_response = [ConversationCommand.Default, ConversationCommand.Text]
|
||||
else:
|
||||
final_response = agent_tools
|
||||
return final_response
|
||||
|
||||
|
||||
async def infer_webpage_urls(
|
||||
|
|
|
@ -365,7 +365,7 @@ tool_descriptions_for_llm = {
|
|||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create plaintext documents, and create charts with quantitative data. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
|
||||
}
|
||||
|
||||
|
@ -373,14 +373,13 @@ function_calling_description_for_llm = {
|
|||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.",
|
||||
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create plaintext documents, and create charts with quantitative data. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This does not support generating charts or graphs.",
|
||||
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
|
||||
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
|
||||
ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text.",
|
||||
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This DOES NOT support generating charts or graphs. It is for creative images.",
|
||||
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query or you don't feel strongly about the other modes.",
|
||||
ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text. This does not work for charts, graphs, or quantitative data. It is for mind mapping, flowcharts, etc.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_agent = {
|
||||
|
|
|
@ -18,7 +18,6 @@ from khoj.processor.conversation.offline.chat_model import (
|
|||
)
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
from khoj.utils.constants import default_offline_chat_models
|
||||
|
||||
|
||||
|
@ -549,34 +548,6 @@ def test_filter_questions():
|
|||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_text_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "text"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
|
|
@ -7,7 +7,7 @@ from freezegun import freeze_time
|
|||
from khoj.database.models import Agent, Entry, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from khoj.routers.helpers import aget_relevant_tools_to_execute
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
SKIP_TESTS = True
|
||||
|
@ -735,7 +735,7 @@ async def test_get_correct_tools_online(client_offline_chat):
|
|||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -750,7 +750,7 @@ async def test_get_correct_tools_notes(client_offline_chat):
|
|||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -765,7 +765,7 @@ async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat
|
|||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -782,7 +782,7 @@ async def test_get_correct_tools_general(client_offline_chat):
|
|||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, is_task=False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -806,7 +806,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat, default_
|
|||
chat_history = create_conversation(chat_log, default_user2)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history, is_task=False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, chat_history, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
|
|
@ -8,8 +8,7 @@ from freezegun import freeze_time
|
|||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import (
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
aget_relevant_tools_to_execute,
|
||||
generate_online_subqueries,
|
||||
infer_webpage_urls,
|
||||
schedule_query,
|
||||
|
@ -524,26 +523,6 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
|
|||
), "Expected search query to include site:khoj.dev but got: " + str(responses)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, expected_mode",
|
||||
[
|
||||
("What's the latest in the Israel/Palestine conflict?", "text"),
|
||||
("Summarize the latest tech news every Monday evening", "automation"),
|
||||
("Paint a scenery in Timbuktu in the winter", "image"),
|
||||
("Remind me, when did I last visit the Serengeti?", "text"),
|
||||
],
|
||||
)
|
||||
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == expected_mode
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
|
@ -559,7 +538,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
|
|||
chat_client, user_query, expected_conversation_commands
|
||||
):
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||
conversation_commands = await aget_relevant_tools_to_execute(user_query, {}, False, False)
|
||||
|
||||
# Assert
|
||||
assert set(expected_conversation_commands) == set(conversation_commands)
|
||||
|
|
|
@ -8,7 +8,7 @@ from freezegun import freeze_time
|
|||
from khoj.database.models import Agent, Entry, KhojUser, LocalPdfConfig
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from khoj.routers.helpers import aget_relevant_tools_to_execute
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
# Initialize variables for tests
|
||||
|
@ -719,7 +719,7 @@ async def test_get_correct_tools_online(chat_client):
|
|||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -734,7 +734,7 @@ async def test_get_correct_tools_notes(chat_client):
|
|||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -749,7 +749,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
|
|||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -766,7 +766,7 @@ async def test_get_correct_tools_general(chat_client):
|
|||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, {}, False, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -790,7 +790,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
|
|||
chat_history = generate_history(chat_log)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history, False, False)
|
||||
tools = await aget_relevant_tools_to_execute(user_query, chat_history, False, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
|
Loading…
Reference in a new issue