Improve Document, Online Search to Answer Vague or Meta Questions (#870)

- Major
  - Improve doc search actor performance on vague, random or meta questions
  - Pass user's name to document and online search actors prompts

- Minor
  - Fix and improve openai chat actor tests
  - Remove unused max tokens arg to extract qs func of doc search actor
This commit is contained in:
Debanjum 2024-08-16 06:46:13 -07:00 committed by GitHub
commit 39e566ba91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 96 additions and 31 deletions

View file

@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import ( from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff, anthropic_chat_completion_with_backoff,
@ -24,14 +24,16 @@ def extract_questions_anthropic(
model: Optional[str] = "claude-instant-1.2", model: Optional[str] = "claude-instant-1.2",
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
temperature=0, temperature=0.7,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join( chat_history = "".join(
@ -50,11 +52,13 @@ def extract_questions_anthropic(
system_prompt = prompts.extract_questions_anthropic_system_prompt.format( system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
current_date=today.strftime("%Y-%m-%d"), current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"), day_of_week=today.strftime("%A"),
current_month=today.strftime("%Y-%m"),
last_new_year=last_new_year.strftime("%Y"), last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"), last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location, location=location,
username=username,
) )
prompt = prompts.extract_questions_anthropic_user_message.format( prompt = prompts.extract_questions_anthropic_user_message.format(

View file

@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp import Llama from llama_cpp import Llama
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
@ -30,7 +30,9 @@ def extract_questions_offline(
use_history: bool = True, use_history: bool = True,
should_extract_questions: bool = True, should_extract_questions: bool = True,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
max_prompt_size: int = None, max_prompt_size: int = None,
temperature: float = 0.7,
) -> List[str]: ) -> List[str]:
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
@ -45,6 +47,7 @@ def extract_questions_offline(
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "" chat_history = ""
@ -64,10 +67,12 @@ def extract_questions_offline(
chat_history=chat_history, chat_history=chat_history,
current_date=today.strftime("%Y-%m-%d"), current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"), day_of_week=today.strftime("%A"),
current_month=today.strftime("%Y-%m"),
yesterday_date=yesterday, yesterday_date=yesterday,
last_year=last_year, last_year=last_year,
this_year=today.year, this_year=today.year,
location=location, location=location,
username=username,
) )
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(
@ -77,7 +82,11 @@ def extract_questions_offline(
state.chat_lock.acquire() state.chat_lock.acquire()
try: try:
response = send_message_to_model_offline( response = send_message_to_model_offline(
messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size messages,
loaded_model=offline_chat_model,
model=model,
max_prompt_size=max_prompt_size,
temperature=temperature,
) )
finally: finally:
state.chat_lock.release() state.chat_lock.release()
@ -229,6 +238,7 @@ def send_message_to_model_offline(
messages: List[ChatMessage], messages: List[ChatMessage],
loaded_model=None, loaded_model=None,
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
temperature: float = 0.2,
streaming=False, streaming=False,
stop=[], stop=[],
max_prompt_size: int = None, max_prompt_size: int = None,
@ -236,7 +246,9 @@ def send_message_to_model_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages] messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) response = offline_chat_model.create_chat_completion(
messages_dict, stop=stop, stream=streaming, temperature=temperature
)
if streaming: if streaming:
return response return response
else: else:

View file

@ -5,7 +5,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff, chat_completion_with_backoff,
@ -24,14 +24,15 @@ def extract_questions(
conversation_log={}, conversation_log={},
api_key=None, api_key=None,
api_base_url=None, api_base_url=None,
temperature=0, temperature=0.7,
max_tokens=100,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None,
): ):
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join( chat_history = "".join(
@ -50,6 +51,7 @@ def extract_questions(
prompt = prompts.extract_questions.format( prompt = prompts.extract_questions.format(
current_date=today.strftime("%Y-%m-%d"), current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"), day_of_week=today.strftime("%A"),
current_month=today.strftime("%Y-%m"),
last_new_year=last_new_year.strftime("%Y"), last_new_year=last_new_year.strftime("%Y"),
last_new_year_date=last_new_year.strftime("%Y-%m-%d"), last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
current_new_year_date=current_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
@ -59,6 +61,7 @@ def extract_questions(
text=text, text=text,
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location, location=location,
username=username,
) )
messages = [ChatMessage(content=prompt, role="user")] messages = [ChatMessage(content=prompt, role="user")]

View file

@ -36,7 +36,7 @@ def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
) -> str: ) -> str:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI = openai_clients.get(client_key) client: openai.OpenAI | None = openai_clients.get(client_key)
if not client: if not client:
client = openai.OpenAI( client = openai.OpenAI(
api_key=openai_api_key, api_key=openai_api_key,

View file

@ -208,10 +208,12 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add as much context from the previous questions and answers as required into your search queries. - Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information. - Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
- Share relevant search queries as a JSON list of strings. Do not say anything else. - Share relevant search queries as a JSON list of strings. Do not say anything else.
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Examples: Examples:
Q: How was my trip to Cambodia? Q: How was my trip to Cambodia?
@ -238,6 +240,9 @@ Khoj: ["What kind of plants do I have?", "What issues do my plants have?"]
Q: Who all did I meet here yesterday? Q: Who all did I meet here yesterday?
Khoj: ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"] Khoj: ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]
Q: Share some random, interesting experiences from this month
Khoj: ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]
Chat History: Chat History:
{chat_history} {chat_history}
What searches will you perform to answer the following question, using the chat history as reference? Respond only with relevant search queries as a valid JSON list of strings. What searches will you perform to answer the following question, using the chat history as reference? Respond only with relevant search queries as a valid JSON list of strings.
@ -254,10 +259,12 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add as much context from the previous questions and answers as required into your search queries. - Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information. - Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Q: How was my trip to Cambodia? Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?"]}} Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
@ -279,6 +286,10 @@ Q: How many tennis balls fit in the back of a 2002 Honda Civic?
Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}} Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}}
A: 1085 tennis balls will fit in the trunk of a Honda Civic A: 1085 tennis balls will fit in the trunk of a Honda Civic
Q: Share some random, interesting experiences from this month
Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
Q: Is Bob older than Tom? Q: Is Bob older than Tom?
Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}} Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}}
A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old. A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old.
@ -305,11 +316,13 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add as much context from the previous questions and answers as required into your search queries. - Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information. - Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else. What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
Current Date: {day_of_week}, {current_date} Current Date: {day_of_week}, {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples of how you can construct search queries to answer the user's question: Here are some examples of how you can construct search queries to answer the user's question:
@ -328,6 +341,11 @@ A: I can help you live healthier and happier across work and personal life
User: Who all did I meet here yesterday? User: Who all did I meet here yesterday?
Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}} Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam. A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
User: Share some random, interesting experiences from this month
Assistant: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}}
A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid.
""".strip() """.strip()
) )
@ -525,6 +543,7 @@ Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object. Provide web page links as a list of strings in a JSON object.
Current Date: {current_date} Current Date: {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples: Here are some examples:
History: History:
@ -571,6 +590,7 @@ What Google searches, if any, will you need to perform to answer the user's ques
Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock. Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock.
Current Date: {current_date} Current Date: {current_date}
User's Location: {location} User's Location: {location}
{username}
Here are some examples: Here are some examples:
History: History:

View file

@ -10,6 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from markdownify import markdownify from markdownify import markdownify
from khoj.database.models import KhojUser
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ChatEvent, ChatEvent,
extract_relevant_info, extract_relevant_info,
@ -51,6 +52,7 @@ async def search_online(
query: str, query: str,
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
): ):
@ -61,7 +63,7 @@ async def search_online(
return return
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location) subqueries = await generate_online_subqueries(query, conversation_history, location, user)
response_dict = {} response_dict = {}
if subqueries: if subqueries:
@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
async def read_webpages( async def read_webpages(
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None query: str,
conversation_history: dict,
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"): async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location) urls = await infer_webpage_urls(query, conversation_history, location, user)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:

View file

@ -388,6 +388,7 @@ async def extract_references_and_questions(
conversation_log=meta_log, conversation_log=meta_log,
should_extract_questions=True, should_extract_questions=True,
location_data=location_data, location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -402,7 +403,7 @@ async def extract_references_and_questions(
api_base_url=base_url, api_base_url=base_url,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
max_tokens=conversation_config.max_prompt_size, user=user,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -413,6 +414,7 @@ async def extract_references_and_questions(
api_key=api_key, api_key=api_key,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user,
) )
# Collate search results as context for GPT # Collate search results as context for GPT

View file

@ -792,7 +792,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands: if ConversationCommand.Online in conversation_commands:
try: try:
async for result in search_online( async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -809,7 +809,7 @@ async def chat(
if ConversationCommand.Webpage in conversation_commands: if ConversationCommand.Webpage in conversation_commands:
try: try:
async for result in read_webpages( async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]

View file

@ -340,11 +340,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
return ConversationCommand.Text return ConversationCommand.Text
async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: async def infer_webpage_urls(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
""" """
Infer webpage links from the given query Infer webpage links from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d") utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@ -353,6 +356,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
query=q, query=q,
chat_history=chat_history, chat_history=chat_history,
location=location, location=location,
username=username,
) )
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
@ -370,11 +374,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
raise ValueError(f"Invalid list of urls: {response}") raise ValueError(f"Invalid list of urls: {response}")
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: async def generate_online_subqueries(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d") utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@ -383,6 +390,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
query=q, query=q,
chat_history=chat_history, chat_history=chat_history,
location=location, location=location,
username=username,
) )
with timer("Chat actor: Generate online search subqueries", logger): with timer("Chat actor: Generate online search subqueries", logger):

View file

@ -17,6 +17,7 @@ from khoj.routers.helpers import (
) )
from khoj.utils.helpers import ConversationCommand from khoj.utils.helpers import ConversationCommand
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
from tests.conftest import default_user2
# Initialize variables for tests # Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
@ -412,18 +413,23 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality @pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(): def test_ask_for_clarification_if_not_enough_context_in_question():
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
# Arrange # Arrange
context = [ context = [
f"""# Ramya {
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""", "compiled": f"""# Ramya
f"""# Fang My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""", },
f"""# Aiyla {
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""", "compiled": f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
},
{
"compiled": f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
},
] ]
# Act # Act
@ -481,12 +487,12 @@ def test_agent_prompt_should_be_used(openai_agent):
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
@freeze_time("2024-04-04", ignore=["transformers"]) @freeze_time("2024-04-04", ignore=["transformers"])
async def test_websearch_with_operators(chat_client): async def test_websearch_with_operators(chat_client, default_user2):
# Arrange # Arrange
user_query = "Share popular posts on r/worldnews this month" user_query = "Share popular posts on r/worldnews this month"
# Act # Act
responses = await generate_online_subqueries(user_query, {}, None) responses = await generate_online_subqueries(user_query, {}, None, default_user2)
# Assert # Assert
assert any( assert any(
@ -501,12 +507,12 @@ async def test_websearch_with_operators(chat_client):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
async def test_websearch_khoj_website_for_info_about_khoj(chat_client): async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_user2):
# Arrange # Arrange
user_query = "Do you support image search?" user_query = "Do you support image search?"
# Act # Act
responses = await generate_online_subqueries(user_query, {}, None) responses = await generate_online_subqueries(user_query, {}, None, default_user2)
# Assert # Assert
assert any( assert any(
@ -558,12 +564,12 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client): async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, default_user2):
# Arrange # Arrange
user_query = "Summarize the wikipedia page on the history of the internet" user_query = "Summarize the wikipedia page on the history of the internet"
# Act # Act
urls = await infer_webpage_urls(user_query, {}, None) urls = await infer_webpage_urls(user_query, {}, None, default_user2)
# Assert # Assert
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
@ -667,6 +673,10 @@ def populate_chat_history(message_list):
conversation_log["chat"] += message_to_log( conversation_log["chat"] += message_to_log(
user_message, user_message,
gpt_message, gpt_message,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, khoj_message_metadata={
"context": context,
"intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
},
conversation_log=[],
) )
return conversation_log return conversation_log