Working prototype of meta-level chain of reasoning and execution

- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information
- Lots of hacks and code that need to be reverted later on before submission
This commit is contained in:
sabaimran 2024-10-09 15:54:25 -07:00
parent 00546c1a63
commit f867d5ed72
6 changed files with 906 additions and 531 deletions

View file

@ -485,6 +485,47 @@ Khoj:
""".strip()
)
# Planner prompt for multi-step research: given the chat history and the results
# of previous tool invocations, the model must pick the next data source and
# query as strict JSON, or return an empty object ({{}}) to stop researching.
# NOTE: grammar fixed in two instruction lines ("iterations to work with",
# "would you use") — the previous wording was garbled and could confuse the model.
plan_function_execution = PromptTemplate.from_template(
    """
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time
- You are given multiple iterations to work with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution
If you already know the answer to the question, return an empty response, e.g., {{}}.
Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:
{tools}
Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.
Previous Iterations:
{previous_iterations}
Response format:
{{"data_source": "<tool_name>", "query": "<your_new_query>"}}
Chat History:
{chat_history}
Q: {query}
Khoj:
""".strip()
)
# Template serializing one past research iteration (tool used, query issued, and
# any document/online context gathered) so it can be fed back to the planner
# prompt via its {previous_iterations} slot.
previous_iteration = PromptTemplate.from_template(
    """
data_source: {data_source}
query: {query}
context: {context}
onlineContext: {onlineContext}
---
""".strip()
)
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.

View file

@ -355,9 +355,10 @@ async def extract_references_and_questions(
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
if (
not ConversationCommand.Notes in conversation_commands
and not ConversationCommand.Default in conversation_commands
and not agent_has_entries
# not ConversationCommand.Notes in conversation_commands
# and not ConversationCommand.Default in conversation_commands
# and not agent_has_entries
True
):
yield compiled_references, inferred_queries, q
return

File diff suppressed because it is too large Load diff

View file

@ -14,6 +14,7 @@ from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
@ -39,6 +40,7 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
create_khoj_token,
get_khoj_tokens,
get_user_name,
@ -614,6 +616,58 @@ async def extract_relevant_summary(
return response.strip()
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    meta_log: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    send_response_func: Optional[Callable] = None,
):
    """
    Summarize the contents of the selected file (or the agent's attached file).

    Async generator: yields whatever the supplied send_status_func /
    send_response_func callbacks yield. Both callbacks are optional and are
    skipped when not provided (previously calling them unguarded crashed with
    the declared defaults of None).
    """
    try:
        # Start with an empty list so the length checks below are safe even when
        # neither lookup branch runs (previously None -> TypeError on len()).
        file_objects = []
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent)

        # An explicit user file filter takes precedence over the agent's documents.
        if len(file_filters) > 0:
            file_objects = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

        if len(file_objects) == 0:
            response_log = (
                "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
            )
            if send_response_func:
                async for result in send_response_func(response_log):
                    yield result
            return

        contextual_data = " ".join([file.raw_text for file in file_objects])
        if not q:
            q = "Create a general summary of the file"
        if send_status_func:
            async for result in send_status_func(f"**Constructing Summary Using:** {file_objects[0].file_name}"):
                yield result

        response = await extract_relevant_summary(
            q,
            contextual_data,
            conversation_history=meta_log,
            subscribed=subscribed,
            uploaded_image_url=uploaded_image_url,
            agent=agent,
        )
        response_log = str(response)
        if send_response_func:
            async for result in send_response_func(response_log):
                yield result
    except Exception as e:
        # Best-effort error reporting back to the client; log the real cause.
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        if send_response_func:
            async for result in send_response_func(response_log):
                yield result
async def generate_better_image_prompt(
q: str,
conversation_history: str,
@ -893,6 +947,7 @@ def generate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
meta_research: str = "",
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [],
@ -910,6 +965,9 @@ def generate_chat_response(
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@ -937,7 +995,7 @@ def generate_chat_response(
chat_response = converse_offline(
references=compiled_references,
online_results=online_results,
user_query=q,
user_query=query_to_run,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
@ -956,7 +1014,7 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
q,
query_to_run,
image_url=uploaded_image_url,
online_results=online_results,
conversation_log=meta_log,
@ -977,7 +1035,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic(
compiled_references,
q,
query_to_run,
online_results,
meta_log,
model=conversation_config.chat_model,
@ -994,7 +1052,7 @@ def generate_chat_response(
api_key = conversation_config.openai_config.api_key
chat_response = converse_gemini(
compiled_references,
q,
query_to_run,
online_results,
meta_log,
model=conversation_config.chat_model,

View file

@ -0,0 +1,261 @@
import json
import logging
from typing import Any, Callable, Dict, List, Optional
from fastapi import Request
from khoj.database.adapters import EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import remove_json_codeblock
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ChatEvent,
construct_chat_history,
generate_summary_from_files,
send_message_to_model_wrapper,
)
from khoj.utils.helpers import (
ConversationCommand,
function_calling_description_for_llm,
timer,
)
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class InformationCollectionIteration:
    """Record of one research iteration: the data source chosen, the query sent
    to it, and whatever document/online context (or raw result) it produced."""

    def __init__(
        self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None
    ):
        self.data_source = data_source
        self.query = query
        self.context = context
        self.onlineContext = onlineContext
        # Bug fix: `result` was accepted by the signature but silently dropped.
        self.result = result
async def apick_next_tool(
    query: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    previous_iterations: Optional[List[InformationCollectionIteration]] = None,
):
    """
    Given a query, determine which of the available tools the agent should use
    in order to answer appropriately. One tool is picked at a time; subsequent
    iterations can use the accumulated results to refine the answer.

    Returns an InformationCollectionIteration with data_source/query set, or
    with both None if the model response could not be parsed.
    """
    # Build the list of tool options the model may choose from. If the agent
    # restricts its input tools, only offer that subset.
    tool_options = dict()
    tool_options_str = ""
    agent_tools = agent.input_tools if agent else []
    for tool, description in function_calling_description_for_llm.items():
        tool_options[tool.value] = description
        if len(agent_tools) == 0 or tool.value in agent_tools:
            tool_options_str += f'- "{tool.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    # Serialize prior iterations so the model can avoid repeating itself.
    # Guard against the default None (iterating None raised TypeError before).
    previous_iterations_history = ""
    for iteration in previous_iterations or []:
        previous_iterations_history += prompts.previous_iteration.format(
            query=iteration.query,
            data_source=iteration.data_source,
            context=str(iteration.context),
            onlineContext=str(iteration.onlineContext),
        )

    if uploaded_image_url:
        query = f"[placeholder for user attached image]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    function_planning_prompt = prompts.plan_function_execution.format(
        query=query,
        tools=tool_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
        previous_iterations=previous_iterations_history,
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(
            function_planning_prompt,
            response_type="json_object",
            subscribed=subscribed,
        )

    try:
        response = response.strip()
        response = remove_json_codeblock(response)
        response = json.loads(response)
        suggested_data_source = response.get("data_source", None)
        suggested_query = response.get("query", None)

        return InformationCollectionIteration(
            data_source=suggested_data_source,
            query=suggested_query,
        )
    except Exception as e:
        logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
        # Fall back to a sentinel iteration; the caller treats a None data
        # source as "stop researching".
        return InformationCollectionIteration(
            data_source=None,
            query=None,
        )
async def execute_information_collection(
    request: Request,
    user: KhojUser,
    query: str,
    conversation_id: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    location: LocationData = None,
    file_filters: Optional[List[str]] = None,
):
    """
    Iteratively pick a data source (notes, online search, webpage reading, or
    file summarization), execute it, and record each iteration so the planner
    can refine its next move.

    Async generator: yields status strings while running, then yields the
    completed InformationCollectionIteration records.
    """
    # Avoid the shared mutable-default-argument pitfall (was `= []`).
    file_filters = file_filters or []
    iteration = 0
    MAX_ITERATIONS = 2
    previous_iterations: List[InformationCollectionIteration] = []
    while iteration < MAX_ITERATIONS:
        online_results: Dict = dict()
        compiled_references, inferred_queries, defiltered_query = [], [], None

        this_iteration = await apick_next_tool(
            query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
        )
        if this_iteration.data_source == ConversationCommand.Notes:
            ## Extract Document References
            async for result in extract_references_and_questions(
                request,
                conversation_history,
                this_iteration.query,
                7,
                None,
                conversation_id,
                [ConversationCommand.Default],
                location,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    compiled_references.extend(result[0])
                    inferred_queries.extend(result[1])
                    defiltered_query = result[2]

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    context=str(compiled_references),
                )
            )

        elif this_iteration.data_source == ConversationCommand.Online:
            async for result in search_online(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                [],
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Webpage:
            async for result in read_webpages(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    direct_web_pages = result
                    webpages = []
                    # Loop variable renamed from `query`: it previously shadowed
                    # and clobbered the user's `query` parameter, corrupting
                    # later planner iterations.
                    for webpage_query in direct_web_pages:
                        if online_results.get(webpage_query):
                            online_results[webpage_query]["webpages"] = direct_web_pages[webpage_query]["webpages"]
                        else:
                            online_results[webpage_query] = {"webpages": direct_web_pages[webpage_query]["webpages"]}

                        for webpage in direct_web_pages[webpage_query]["webpages"]:
                            webpages.append(webpage["link"])
                    # Iterate the status generator instead of yielding the bare
                    # async-generator object (the previous `yield
                    # send_status_func(...)` never sent anything).
                    if send_status_func:
                        async for event in send_status_func(f"**Read web pages**: {webpages}"):
                            if isinstance(event, dict) and ChatEvent.STATUS in event:
                                yield event[ChatEvent.STATUS]

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Summarize:
            response_log = ""
            agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
            if len(file_filters) == 0 and not agent_has_entries:
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context="No files selected for summarization.",
                    )
                )
            elif len(file_filters) > 1 and not agent_has_entries:
                response_log = "Only one file can be selected for summarization."
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
            else:

                async def collect_response(message: str):
                    # Capture the summary text instead of streaming it to a client.
                    nonlocal response_log
                    response_log = message
                    yield message

                # generate_summary_from_files is an async generator: it must be
                # iterated, not awaited (the previous `await` raised TypeError).
                async for result in generate_summary_from_files(
                    q=query,
                    user=user,
                    file_filters=file_filters,
                    meta_log=conversation_history,
                    subscribed=subscribed,
                    send_status_func=send_status_func,
                    send_response_func=collect_response,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]

                # Record the summarization result so the planner sees it
                # (the success path previously recorded nothing).
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
        else:
            # Planner chose no (or an unknown) data source: stop researching.
            iteration = MAX_ITERATIONS

        iteration += 1
    for completed_iter in previous_iterations:
        yield completed_iter

View file

@ -345,6 +345,13 @@ tool_descriptions_for_llm = {
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}
# Descriptions of each data source shown to the planner LLM when it decides
# which tool to invoke next during multi-step research.
# Typos fixed in the Online/Webpage descriptions ("the there's" -> "there's",
# "you are share of" -> "you are sure of") so the model instructions read cleanly.
function_calling_description_for_llm = {
    ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.",
    ConversationCommand.Online: "Use this if you think there's important information on the internet related to the query.",
    ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are sure of a webpage URL that will help you directly answer this query.",
    ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.",
}
mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",