mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Pass user name to document and online search actor prompts
This should improve the quality of personal information extraction from document and online sources. The user name is only used when it is set.
This commit is contained in:
parent
eb5af38f33
commit
a47a54f207
9 changed files with 42 additions and 11 deletions
|
@ -6,7 +6,7 @@ from typing import Dict, Optional
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.anthropic.utils import (
|
||||
anthropic_chat_completion_with_backoff,
|
||||
|
@ -26,12 +26,14 @@ def extract_questions_anthropic(
|
|||
api_key=None,
|
||||
temperature=0,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
|
@ -55,6 +57,7 @@ def extract_questions_anthropic(
|
|||
current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_anthropic_user_message.format(
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Any, Iterator, List, Union
|
|||
from langchain.schema import ChatMessage
|
||||
from llama_cpp import Llama
|
||||
|
||||
from khoj.database.models import Agent
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
|
@ -30,6 +30,7 @@ def extract_questions_offline(
|
|||
use_history: bool = True,
|
||||
should_extract_questions: bool = True,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
max_prompt_size: int = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
@ -45,6 +46,7 @@ def extract_questions_offline(
|
|||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = ""
|
||||
|
@ -68,6 +70,7 @@ def extract_questions_offline(
|
|||
last_year=last_year,
|
||||
this_year=today.year,
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, Optional
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
|
@ -27,11 +27,13 @@ def extract_questions(
|
|||
temperature=0,
|
||||
max_tokens=100,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = "".join(
|
||||
|
@ -59,6 +61,7 @@ def extract_questions(
|
|||
text=text,
|
||||
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ def completion_with_backoff(
|
|||
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
|
||||
) -> str:
|
||||
client_key = f"{openai_api_key}--{api_base_url}"
|
||||
client: openai.OpenAI = openai_clients.get(client_key)
|
||||
client: openai.OpenAI | None = openai_clients.get(client_key)
|
||||
if not client:
|
||||
client = openai.OpenAI(
|
||||
api_key=openai_api_key,
|
||||
|
|
|
@ -212,6 +212,7 @@ Construct search queries to retrieve relevant information to answer the user's q
|
|||
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Examples:
|
||||
Q: How was my trip to Cambodia?
|
||||
|
@ -258,6 +259,7 @@ Construct search queries to retrieve relevant information to answer the user's q
|
|||
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Q: How was my trip to Cambodia?
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||
|
@ -310,6 +312,7 @@ What searches will you perform to answer the users question? Respond with a JSON
|
|||
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Here are some examples of how you can construct search queries to answer the user's question:
|
||||
|
||||
|
@ -525,6 +528,7 @@ Which webpages will you need to read to answer the user's question?
|
|||
Provide web page links as a list of strings in a JSON object.
|
||||
Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Here are some examples:
|
||||
History:
|
||||
|
@ -571,6 +575,7 @@ What Google searches, if any, will you need to perform to answer the user's ques
|
|||
Provide search queries as a list of strings in a JSON object. Do not wrap the json in a codeblock.
|
||||
Current Date: {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Here are some examples:
|
||||
History:
|
||||
|
|
|
@ -10,6 +10,7 @@ import aiohttp
|
|||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
extract_relevant_info,
|
||||
|
@ -51,6 +52,7 @@ async def search_online(
|
|||
query: str,
|
||||
conversation_history: dict,
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
custom_filters: List[str] = [],
|
||||
):
|
||||
|
@ -61,7 +63,7 @@ async def search_online(
|
|||
return
|
||||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(query, conversation_history, location)
|
||||
subqueries = await generate_online_subqueries(query, conversation_history, location, user)
|
||||
response_dict = {}
|
||||
|
||||
if subqueries:
|
||||
|
@ -126,14 +128,18 @@ async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
|||
|
||||
|
||||
async def read_webpages(
|
||||
query: str, conversation_history: dict, location: LocationData, send_status_func: Optional[Callable] = None
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**🧐 Inferring web pages to read**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
urls = await infer_webpage_urls(query, conversation_history, location)
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user)
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
|
|
|
@ -343,6 +343,7 @@ async def extract_references_and_questions(
|
|||
conversation_log=meta_log,
|
||||
should_extract_questions=True,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -357,6 +358,7 @@ async def extract_references_and_questions(
|
|||
api_base_url=base_url,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
max_tokens=conversation_config.max_prompt_size,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
|
@ -368,6 +370,7 @@ async def extract_references_and_questions(
|
|||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
|
|
|
@ -800,7 +800,7 @@ async def chat(
|
|||
if ConversationCommand.Online in conversation_commands:
|
||||
try:
|
||||
async for result in search_online(
|
||||
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS), custom_filters
|
||||
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS), custom_filters
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -817,7 +817,7 @@ async def chat(
|
|||
if ConversationCommand.Webpage in conversation_commands:
|
||||
try:
|
||||
async for result in read_webpages(
|
||||
defiltered_query, meta_log, location, partial(send_event, ChatEvent.STATUS)
|
||||
defiltered_query, meta_log, location, user, partial(send_event, ChatEvent.STATUS)
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
|
|
@ -315,11 +315,14 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
|
|||
return ConversationCommand.Text
|
||||
|
||||
|
||||
async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
|
||||
async def infer_webpage_urls(
|
||||
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer webpage links from the given query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
|
@ -328,6 +331,7 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
|
|||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
|
@ -345,11 +349,14 @@ async def infer_webpage_urls(q: str, conversation_history: dict, location_data:
|
|||
raise ValueError(f"Invalid list of urls: {response}")
|
||||
|
||||
|
||||
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
|
||||
async def generate_online_subqueries(
|
||||
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
|
@ -358,6 +365,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
|
|||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
username=username,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
|
|
Loading…
Add table
Reference in a new issue