mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Have Khoj dynamically select conversation command(s) in chat (#641)
* Have Khoj dynamically select which conversation command(s) are to be used in the chat flow - Intercept the commands if in default mode, and have Khoj dynamically guess which tools would be the most relevant for answering the user's query * Remove conditional for default to enter online search mode * Add multiple-tool examples in the prompt, make prompt for tools more specific to info collection
This commit is contained in:
parent
69344a6aa6
commit
a3eb17b7d4
10 changed files with 373 additions and 63 deletions
|
@ -479,7 +479,7 @@ class ConversationAdapters:
|
|||
conversation_id: int = None,
|
||||
user_message: str = None,
|
||||
):
|
||||
slug = user_message.strip()[:200] if not is_none_or_empty(user_message) else None
|
||||
slug = user_message.strip()[:200] if user_message else None
|
||||
if conversation_id:
|
||||
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
|
||||
else:
|
||||
|
|
|
@ -130,7 +130,7 @@ def converse_offline(
|
|||
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
|
@ -148,27 +148,24 @@ def converse_offline(
|
|||
# Initialize Variables
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references_message):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
elif conversation_command == ConversationCommand.Online:
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
simplified_online_results = online_results.copy()
|
||||
for result in online_results:
|
||||
if online_results[result].get("extracted_content"):
|
||||
simplified_online_results[result] = online_results[result]["extracted_content"]
|
||||
|
||||
conversation_primer = prompts.online_search_conversation.format(
|
||||
query=user_query, online_results=str(simplified_online_results)
|
||||
)
|
||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
|
||||
conversation_primer = user_query
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation_gpt4all.format(
|
||||
query=user_query, references=compiled_references_message
|
||||
)
|
||||
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
|
||||
if ConversationCommand.Notes in conversation_commands:
|
||||
conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
|
|
@ -122,7 +122,7 @@ def converse(
|
|||
api_key: Optional[str] = None,
|
||||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
conversation_commands=[ConversationCommand.Default],
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
):
|
||||
|
@ -133,26 +133,25 @@ def converse(
|
|||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
compiled_references = "\n\n".join({f"# {item}" for item in references})
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_command == ConversationCommand.Notes and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_command == ConversationCommand.Online and is_none_or_empty(online_results):
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
elif conversation_command == ConversationCommand.Online:
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
simplified_online_results = online_results.copy()
|
||||
for result in online_results:
|
||||
if online_results[result].get("extracted_content"):
|
||||
simplified_online_results[result] = online_results[result]["extracted_content"]
|
||||
|
||||
conversation_primer = prompts.online_search_conversation.format(
|
||||
query=user_query, online_results=str(simplified_online_results)
|
||||
)
|
||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
|
||||
conversation_primer = prompts.general_conversation.format(query=user_query)
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
|
||||
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
|
||||
if ConversationCommand.Notes in conversation_commands:
|
||||
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n{conversation_primer}"
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -112,7 +112,6 @@ notes_conversation_gpt4all = PromptTemplate.from_template(
|
|||
"""
|
||||
User's Notes:
|
||||
{references}
|
||||
Question: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
@ -139,7 +138,13 @@ Use this up-to-date information from the internet to inform your response.
|
|||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
|
||||
|
||||
Information from the internet: {online_results}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Query prompt
|
||||
## --
|
||||
query_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
Query: {query}""".strip()
|
||||
)
|
||||
|
||||
|
@ -285,6 +290,60 @@ Collate the relevant information from the website to answer the target query.
|
|||
""".strip()
|
||||
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart and helpful personal assistant. You have access to a variety of data sources to help you answer the user's question. You can use the data sources listed below to collect more relevant information. You can use any combination of these data sources to answer the user's question. Tell me which data sources you would like to use to answer the user's question.
|
||||
|
||||
{tools}
|
||||
|
||||
Here are some example responses:
|
||||
|
||||
Example 1:
|
||||
Chat History:
|
||||
User: I'm thinking of moving to a new city. I'm trying to decide between New York and San Francisco.
|
||||
AI: Moving to a new city can be challenging. Both New York and San Francisco are great cities to live in. New York is known for its diverse culture and San Francisco is known for its tech scene.
|
||||
|
||||
Q: What is the population of each of those cities?
|
||||
Khoj: ["online"]
|
||||
|
||||
Example 2:
|
||||
Chat History:
|
||||
User: I've been having a hard time at work. I'm thinking of quitting.
|
||||
AI: I'm sorry to hear that. It's important to take care of your mental health. Have you considered talking to your manager about your concerns?
|
||||
|
||||
Q: What are the best ways to quit a job?
|
||||
Khoj: ["general"]
|
||||
|
||||
Example 3:
|
||||
Chat History:
|
||||
User: I'm thinking of my next vacation idea. Ideally, I want to see something new and exciting.
|
||||
AI: Excellent! Taking a vacation is a great way to relax and recharge.
|
||||
|
||||
Q: Where did Grandma grow up?
|
||||
Khoj: ["notes"]
|
||||
|
||||
Example 4:
|
||||
Chat History:
|
||||
|
||||
Q: I want to make chocolate cake. What was my recipe?
|
||||
Khoj: ["notes"]
|
||||
|
||||
Example 5:
|
||||
Chat History:
|
||||
|
||||
Q: What's the latest news with the first company I worked for?
|
||||
Khoj: ["notes", "online"]
|
||||
|
||||
Now it's your turn to pick the tools you would like to use to answer the user's question. Provide your response as a list of strings.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
|
||||
Q: {query}
|
||||
A:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant. You are tasked with constructing **up to three** search queries for Google to answer the user's question.
|
||||
|
|
|
@ -274,7 +274,7 @@ async def extract_references_and_questions(
|
|||
q: str,
|
||||
n: int,
|
||||
d: float,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
|
@ -282,7 +282,7 @@ async def extract_references_and_questions(
|
|||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[str] = []
|
||||
|
||||
if conversation_type == ConversationCommand.General or conversation_type == ConversationCommand.Online:
|
||||
if not ConversationCommand.Notes in conversation_commands:
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||
|
|
|
@ -21,6 +21,7 @@ from khoj.routers.helpers import (
|
|||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
agenerate_chat_response,
|
||||
aget_relevant_information_sources,
|
||||
get_conversation_command,
|
||||
is_ready_to_chat,
|
||||
text_to_image,
|
||||
|
@ -207,7 +208,7 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
@api_chat.get("", response_class=Response)
|
||||
@api_chat.get("/", response_class=Response)
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
request: Request,
|
||||
|
@ -229,25 +230,9 @@ async def chat(
|
|||
q = unquote(q)
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
conversation_commands = [get_conversation_command(query=q, any_references=True)]
|
||||
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
meta_log = (
|
||||
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
|
||||
).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
)
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||
conversation_command = ConversationCommand.General
|
||||
|
||||
elif conversation_command == ConversationCommand.Help:
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
|
@ -255,7 +240,23 @@ async def chat(
|
|||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
|
||||
meta_log = (
|
||||
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
|
||||
).conversation_log
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
|
||||
)
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
no_entries_found_format = no_entries_found.format()
|
||||
if stream:
|
||||
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
|
||||
|
@ -263,7 +264,10 @@ async def chat(
|
|||
response_obj = {"response": no_entries_found_format}
|
||||
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
|
||||
|
||||
elif conversation_command == ConversationCommand.Online:
|
||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||
conversation_commands.remove(ConversationCommand.Notes)
|
||||
|
||||
if ConversationCommand.Online in conversation_commands:
|
||||
try:
|
||||
online_results = await search_with_google(defiltered_query, meta_log)
|
||||
except ValueError as e:
|
||||
|
@ -272,12 +276,12 @@ async def chat(
|
|||
media_type="text/event-stream",
|
||||
status_code=200,
|
||||
)
|
||||
elif conversation_command == ConversationCommand.Image:
|
||||
elif conversation_commands == [ConversationCommand.Image]:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_command.value},
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
|
||||
|
@ -308,13 +312,13 @@ async def chat(
|
|||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
conversation_command,
|
||||
conversation_commands,
|
||||
user,
|
||||
request.user.client_app,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
chat_metadata.update({"conversation_command": conversation_command.value})
|
||||
chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
|
|
@ -34,7 +34,11 @@ from khoj.processor.conversation.utils import (
|
|||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
log_telemetry,
|
||||
tool_descriptions_for_llm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -105,6 +109,15 @@ def update_telemetry_state(
|
|||
]
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4) -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"Khoj: {chat['message']}\n"
|
||||
return chat_history
|
||||
|
||||
|
||||
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
||||
if query.startswith("/notes"):
|
||||
return ConversationCommand.Notes
|
||||
|
@ -128,15 +141,50 @@ async def agenerate_chat_response(*args):
|
|||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||
|
||||
|
||||
async def aget_relevant_information_sources(query: str, conversation_history: dict):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
"""
|
||||
|
||||
tool_options = dict()
|
||||
|
||||
for tool, description in tool_descriptions_for_llm.items():
|
||||
tool_options[tool.value] = description
|
||||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
|
||||
query=query,
|
||||
tools=str(tool_options),
|
||||
chat_history=chat_history,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(relevant_tools_prompt)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response if q.strip()]
|
||||
if not isinstance(response, list) or not response or len(response) == 0:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
return tool_options
|
||||
|
||||
final_response = []
|
||||
for llm_suggested_tool in response:
|
||||
if llm_suggested_tool in tool_options.keys():
|
||||
# Check whether the tool exists as a valid ConversationCommand
|
||||
final_response.append(ConversationCommand(llm_suggested_tool))
|
||||
return final_response
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
return [ConversationCommand.Default]
|
||||
|
||||
|
||||
async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
"""
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-4:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"Khoj: {chat['message']}\n"
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
|
||||
|
@ -241,14 +289,14 @@ def generate_chat_response(
|
|||
compiled_references: List[str] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
user: KhojUser = None,
|
||||
client_application: ClientApplication = None,
|
||||
conversation_id: int = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
|
||||
|
@ -278,7 +326,7 @@ def generate_chat_response(
|
|||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
conversation_commands=conversation_commands,
|
||||
model=conversation_config.chat_model,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
|
@ -296,7 +344,7 @@ def generate_chat_response(
|
|||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
conversation_commands=conversation_commands,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
)
|
||||
|
|
|
@ -282,6 +282,13 @@ command_descriptions = {
|
|||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||
}
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question",
|
||||
ConversationCommand.General: "Use this when you can answer the question without needing any additional online or personal information",
|
||||
ConversationCommand.Notes: "Use this when you would like to use the user's personal knowledge base to answer the question",
|
||||
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
|
||||
}
|
||||
|
||||
|
||||
def generate_random_name():
|
||||
# List of adjectives and nouns to choose from
|
||||
|
|
|
@ -8,6 +8,7 @@ from freezegun import freeze_time
|
|||
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
SKIP_TESTS = True
|
||||
|
@ -440,3 +441,100 @@ def test_answer_requires_multiple_independent_searches(client_offline_chat):
|
|||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_online(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["online"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_notes(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["notes"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert len(tools) == 2
|
||||
assert "online" or "general" in tools
|
||||
assert "notes" in tools
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_general(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["general"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_with_chat_history(client_offline_chat):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
chat_log = [
|
||||
(
|
||||
"Let's talk about the current events around the world.",
|
||||
"Sure, let's discuss the current events. What would you like to know?",
|
||||
),
|
||||
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st."),
|
||||
]
|
||||
chat_history = populate_chat_history(chat_log)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["online"]
|
||||
|
||||
|
||||
def populate_chat_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message in message_list:
|
||||
conversation_log["chat"] += message_to_log(
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": [], "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
return conversation_log
|
||||
|
|
|
@ -8,6 +8,7 @@ from freezegun import freeze_time
|
|||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
# Initialize variables for tests
|
||||
|
@ -416,3 +417,100 @@ def test_answer_using_file_filter(chat_client):
|
|||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_online(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["online"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_notes(chat_client):
|
||||
# Arrange
|
||||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["notes"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_online_or_general_and_notes(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert len(tools) == 2
|
||||
assert "online" in tools or "general" in tools
|
||||
assert "notes" in tools
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_general(chat_client):
|
||||
# Arrange
|
||||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["general"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_with_chat_history(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
chat_log = [
|
||||
(
|
||||
"Let's talk about the current events around the world.",
|
||||
"Sure, let's discuss the current events. What would you like to know?",
|
||||
),
|
||||
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st."),
|
||||
]
|
||||
chat_history = populate_chat_history(chat_log)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["online"]
|
||||
|
||||
|
||||
def populate_chat_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message in message_list:
|
||||
conversation_log["chat"] += message_to_log(
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": [], "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
return conversation_log
|
||||
|
|
Loading…
Reference in a new issue