Pass user name to document and online search actors prompts

This should improve the quality of personal information extraction
from document and online sources. The user name is only used when it
is set.
This commit is contained in:
Debanjum Singh Solanky 2024-07-26 22:56:24 +05:30
parent eb5af38f33
commit a47a54f207
9 changed files with 42 additions and 11 deletions

View file

@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
@ -26,12 +26,14 @@ def extract_questions_anthropic(
api_key=None,
temperature=0,
location_data: LocationData = None,
user: KhojUser = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
# Extract Past User Message and Inferred Questions from Conversation Log
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
@ -55,6 +57,7 @@ def extract_questions_anthropic(
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
)
prompt = prompts.extract_questions_anthropic_user_message.format(

View file

@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage
from llama_cpp import Llama
from khoj.database.models import Agent
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
user: KhojUser = None,
max_prompt_size: int = None,
) -> List[str]:
"""
@ -45,6 +46,7 @@ def extract_questions_offline(
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = ""
@ -68,6 +70,7 @@ def extract_questions_offline(
last_year=last_year,
this_year=today.year,
location=location,
username=username,
)
messages = generate_chatml_messages_with_context(

View file

@ -5,7 +5,7 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@ -27,11 +27,13 @@ def extract_questions(
temperature=0,
max_tokens=100,
location_data: LocationData = None,
user: KhojUser = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
@ -59,6 +61,7 @@ def extract_questions(
text=text,
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
)
messages = [ChatMessage(content=prompt, role="user")]

View file

@ -36,7 +36,7 @@ def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI = openai_clients.get(client_key)
client: openai.OpenAI | None = openai_clients.get(client_key)
if not client:
client = openai.OpenAI(
api_key=openai_api_key,

View file

@ -212,6 +212,7 @@ Construct search queries to retrieve relevant information to answer the user's q
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Examples:
Q: How was my trip to Cambodia?
@ -258,6 +259,7 @@ Construct search queries to retrieve relevant information to answer the user's q
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
@ -310,6 +312,7 @@ What searches will you perform to answer the users question? Respond with a JSON
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Here are some examples of how you can construct search queries to answer the user's question:
@ -525,6 +528,7 @@ Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object.
Current Date: {current_date}
User's Location: {location}
{username}
Here are some examples:
History:
@ -571,6 +575,7 @@ What Google searches, if any, will you need to perform to answer the user's ques
Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock.
Current Date: {current_date}
User's Location: {location}
{username}
Here are some examples:
History:

View file

@ -10,6 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify
from khoj.database.models import KhojUser
from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info,
@ -51,6 +52,7 @@ async def search_online(
query: str,
conversation_history: dict,
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
):
@ -61,7 +63,7 @@ async def search_online(
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history, location)
subqueries = await generate_online_subqueries(query, conversation_history, location, user)
response_dict = {}
if subqueries:
@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
async def read_webpages(
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
query: str,
conversation_history: dict,
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location)
urls = await infer_webpage_urls(query, conversation_history, location, user)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:

View file

@ -343,6 +343,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
should_extract_questions=True,
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -357,6 +358,7 @@ async def extract_references_and_questions(
api_base_url=base_url,
conversation_log=meta_log,
location_data=location_data,
user=user,
max_tokens=conversation_config.max_prompt_size,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -368,6 +370,7 @@ async def extract_references_and_questions(
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
user=user,
)
# Collate search results as context for GPT

View file

@ -800,7 +800,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands:
try:
async for result in search_online(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -817,7 +817,7 @@ async def chat(
if ConversationCommand.Webpage in conversation_commands:
try:
async for result in read_webpages(
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]

View file

@ -315,11 +315,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
return ConversationCommand.Text
async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
async def infer_webpage_urls(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
"""
Infer webpage links from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@ -328,6 +331,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
query=q,
chat_history=chat_history,
location=location,
username=username,
)
with timer("Chat actor: Infer webpage urls to read", logger):
@ -345,11 +349,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
raise ValueError(f"Invalid list of urls: {response}")
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
async def generate_online_subqueries(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
) -> List[str]:
"""
Generate subqueries from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@ -358,6 +365,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
query=q,
chat_history=chat_history,
location=location,
username=username,
)
with timer("Chat actor: Generate online search subqueries", logger):