diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index d5778885..ef8539b3 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -24,14 +24,16 @@ def extract_questions_anthropic( model: Optional[str] = "claude-instant-1.2", conversation_log={}, api_key=None, - temperature=0, + temperature=0.7, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ # Extract Past User Message and Inferred Questions from Conversation Log location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -50,11 +52,13 @@ def extract_questions_anthropic( system_prompt = prompts.extract_questions_anthropic_system_prompt.format( current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), + current_month=today.strftime("%Y-%m"), last_new_year=last_new_year.strftime("%Y"), last_new_year_date=last_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"), yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) prompt = prompts.extract_questions_anthropic_user_message.format( diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 2da0c186..ec4c7367 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -30,7 +30,9 @@ def extract_questions_offline( use_history: bool = True, should_extract_questions: bool = True, location_data: LocationData = None, + user: KhojUser = None, max_prompt_size: int = None, + temperature: float = 0.7, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -45,6 +47,7 @@ def extract_questions_offline( offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "" @@ -64,10 +67,12 @@ def extract_questions_offline( chat_history=chat_history, current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), + current_month=today.strftime("%Y-%m"), yesterday_date=yesterday, last_year=last_year, this_year=today.year, location=location, + username=username, ) messages = generate_chatml_messages_with_context( @@ -77,7 +82,11 @@ def extract_questions_offline( state.chat_lock.acquire() try: response = send_message_to_model_offline( - messages, loaded_model=offline_chat_model, model=model, max_prompt_size=max_prompt_size + messages, + loaded_model=offline_chat_model, + model=model, + max_prompt_size=max_prompt_size, + temperature=temperature, ) finally: state.chat_lock.release() @@ -229,6 +238,7 @@ def send_message_to_model_offline( messages: List[ChatMessage], loaded_model=None, model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + temperature: float = 0.2, streaming=False, stop=[], max_prompt_size: int = None, @@ -236,7 +246,9 @@ def send_message_to_model_offline( assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) messages_dict = [{"role": message.role, "content": message.content} for message in messages] - response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) + response = offline_chat_model.create_chat_completion( + messages_dict, stop=stop, stream=streaming, temperature=temperature + ) if streaming: return response else: diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index f1608fba..5dad883b 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,7 +5,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -24,14 +24,15 @@ def extract_questions( conversation_log={}, api_key=None, api_base_url=None, - temperature=0, - max_tokens=100, + temperature=0.7, location_data: LocationData = None, + user: KhojUser = None, ): """ Infer search queries to retrieve relevant notes to answer user query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else "" # Extract Past User Message and Inferred Questions from Conversation Log chat_history = "".join( @@ -50,6 +51,7 @@ def extract_questions( prompt = prompts.extract_questions.format( current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), + current_month=today.strftime("%Y-%m"), last_new_year=last_new_year.strftime("%Y"), last_new_year_date=last_new_year.strftime("%Y-%m-%d"), current_new_year_date=current_new_year.strftime("%Y-%m-%d"), @@ -59,6 +61,7 @@ def extract_questions( text=text, yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"), location=location, + username=username, ) messages = [ChatMessage(content=prompt, role="user")] diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index bfc70482..314c3c1d 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -36,7 +36,7 @@ def completion_with_backoff( messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None ) -> str: client_key = f"{openai_api_key}--{api_base_url}" - client: openai.OpenAI = openai_clients.get(client_key) + client: openai.OpenAI | None = openai_clients.get(client_key) if not client: client = openai.OpenAI( api_key=openai_api_key, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 289bafbc..6a8db9db 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -208,10 +208,12 @@ Construct search queries to retrieve relevant information to answer the user's q - Add as much context from the previous questions and answers as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information. +- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question. - Share relevant search queries as a JSON list of strings. Do not say anything else. Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Examples: Q: How was my trip to Cambodia? @@ -238,6 +240,9 @@ Khoj: ["What kind of plants do I have?", "What issues do my plants have?"] Q: Who all did I meet here yesterday? Khoj: ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"] +Q: Share some random, interesting experiences from this month +Khoj: ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"] + Chat History: {chat_history} What searches will you perform to answer the following question, using the chat history as reference? Respond only with relevant search queries as a valid JSON list of strings. @@ -254,10 +259,12 @@ Construct search queries to retrieve relevant information to answer the user's q - Add as much context from the previous questions and answers as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information. +- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question. What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object. Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Q: How was my trip to Cambodia? Khoj: {{"queries": ["How was my trip to Cambodia?"]}} @@ -279,6 +286,10 @@ Q: How many tennis balls fit in the back of a 2002 Honda Civic? Khoj: {{"queries": ["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]}} A: 1085 tennis balls will fit in the trunk of a Honda Civic +Q: Share some random, interesting experiences from this month +Khoj: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}} +A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid. + Q: Is Bob older than Tom? Khoj: {{"queries": ["When was Bob born?", "What is Tom's age?"]}} A: Yes, Bob is older than Tom. As Bob was born on 1984-01-01 and Tom is 30 years old. @@ -305,11 +316,13 @@ Construct search queries to retrieve relevant information to answer the user's q - Add as much context from the previous questions and answers as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. - Add date filters to your search queries from questions and answers when required to retrieve the relevant information. +- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question. What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else. Current Date: {day_of_week}, {current_date} User's Location: {location} +{username} Here are some examples of how you can construct search queries to answer the user's question: @@ -328,6 +341,11 @@ A: I can help you live healthier and happier across work and personal life User: Who all did I meet here yesterday? Assistant: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}} A: Yesterday's note mentions your visit to your local beach with Ram and Shyam. + +User: Share some random, interesting experiences from this month +Assistant: {{"queries": ["Exciting travel adventures from {current_month}", "Fun social events dt>='{current_month}-01' dt<'{current_date}'", "Intense emotional experiences in {current_month}"]}} +A: You had a great time at the local beach with your friends, attended a music concert and had a deep conversation with your friend, Khalid. + """.strip() ) @@ -525,6 +543,7 @@ Which webpages will you need to read to answer the user's question? Provide web page links as a list of strings in a JSON object. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: @@ -571,6 +590,7 @@ What Google searches, if any, will you need to perform to answer the user's ques Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock. Current Date: {current_date} User's Location: {location} +{username} Here are some examples: History: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index caf800b0..70e17630 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,6 +10,7 @@ import aiohttp from bs4 import BeautifulSoup from markdownify import markdownify +from khoj.database.models import KhojUser from khoj.routers.helpers import ( ChatEvent, extract_relevant_info, @@ -51,6 +52,7 @@ async def search_online( query: str, conversation_history: dict, location: LocationData, + user: KhojUser, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], ): @@ -61,7 +63,7 @@ async def search_online( return # Breakdown the query into subqueries to get the correct answer - subqueries = await generate_online_subqueries(query, conversation_history, location) + subqueries = await generate_online_subqueries(query, conversation_history, location, user) response_dict = {} if subqueries: @@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]: async def read_webpages( - query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None + query: str, + conversation_history: dict, + location: LocationData, + user: KhojUser, + send_status_func: Optional[Callable] = None, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") if send_status_func: async for event in send_status_func(f"**Inferring web pages to read**"): yield {ChatEvent.STATUS: event} - urls = await infer_webpage_urls(query, conversation_history, location) + urls = await infer_webpage_urls(query, conversation_history, location, user) logger.info(f"Reading web pages at: {urls}") if send_status_func: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 1c7697e5..953449d3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -388,6 +388,7 @@ async def extract_references_and_questions( conversation_log=meta_log, should_extract_questions=True, location_data=location_data, + user=user, max_prompt_size=conversation_config.max_prompt_size, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -402,7 +403,7 @@ async def extract_references_and_questions( api_base_url=base_url, conversation_log=meta_log, location_data=location_data, - max_tokens=conversation_config.max_prompt_size, + user=user, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -413,6 +414,7 @@ async def extract_references_and_questions( api_key=api_key, conversation_log=meta_log, location_data=location_data, + user=user, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index da0bd2d8..d515006c 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -792,7 +792,7 @@ async def chat( if ConversationCommand.Online in conversation_commands: try: async for result in search_online( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -809,7 +809,7 @@ async def chat( if ConversationCommand.Webpage in conversation_commands: try: async for result in read_webpages( - defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS) + defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS) ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8d5a96ad..31b8d9b5 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -340,11 +340,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_ return ConversationCommand.Text -async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def infer_webpage_urls( + q: str, conversation_history: dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Infer webpage links from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -353,6 +356,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Infer webpage urls to read", logger): @@ -370,11 +374,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data: raise ValueError(f"Invalid list of urls: {response}") -async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]: +async def generate_online_subqueries( + q: str, conversation_history: dict, location_data: LocationData, user: KhojUser +) -> List[str]: """ Generate subqueries from the given query """ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" + username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" chat_history = construct_chat_history(conversation_history) utc_date = datetime.utcnow().strftime("%Y-%m-%d") @@ -383,6 +390,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio query=q, chat_history=chat_history, location=location, + username=username, ) with timer("Chat actor: Generate online search subqueries", logger): diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index ae2a8c55..83aee27f 100644 --- a/tests/test_openai_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -17,6 +17,7 @@ from khoj.routers.helpers import ( ) from khoj.utils.helpers import ConversationCommand from khoj.utils.rawconfig import LocationData +from tests.conftest import default_user2 # Initialize variables for tests api_key = os.getenv("OPENAI_API_KEY") @@ -412,18 +413,23 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(): # ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.") @pytest.mark.chatquality def test_ask_for_clarification_if_not_enough_context_in_question(): "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" # Arrange context = [ - f"""# Ramya -My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""", - f"""# Fang -My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""", - f"""# Aiyla -My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""", + { + "compiled": f"""# Ramya +My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""" + }, + { + "compiled": f"""# Fang +My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""" + }, + { + "compiled": f"""# Aiyla +My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""" + }, ] # Act @@ -481,12 +487,12 @@ def test_agent_prompt_should_be_used(openai_agent): @pytest.mark.anyio @pytest.mark.django_db(transaction=True) @freeze_time("2024-04-04", ignore=["transformers"]) -async def test_websearch_with_operators(chat_client): +async def test_websearch_with_operators(chat_client, default_user2): # Arrange user_query = "Share popular posts on r/worldnews this month" # Act - responses = await generate_online_subqueries(user_query, {}, None) + responses = await generate_online_subqueries(user_query, {}, None, default_user2) # Assert assert any( @@ -501,12 +507,12 @@ async def test_websearch_with_operators(chat_client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_websearch_khoj_website_for_info_about_khoj(chat_client): +async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_user2): # Arrange user_query = "Do you support image search?" # Act - responses = await generate_online_subqueries(user_query, {}, None) + responses = await generate_online_subqueries(user_query, {}, None, default_user2) # Assert assert any( @@ -558,12 +564,12 @@ async def test_select_data_sources_actor_chooses_to_search_notes( # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client): +async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, default_user2): # Arrange user_query = "Summarize the wikipedia page on the history of the internet" # Act - urls = await infer_webpage_urls(user_query, {}, None) + urls = await infer_webpage_urls(user_query, {}, None, default_user2) # Assert assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls @@ -667,6 +673,10 @@ def populate_chat_history(message_list): conversation_log["chat"] += message_to_log( user_message, gpt_message, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + khoj_message_metadata={ + "context": context, + "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}, + }, + conversation_log=[], ) return conversation_log