mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Working prototype of meta-level chain of reasoning and execution
- Create a more dynamic reasoning agent that can evaluate information and understand what it doesn't know, making moves to get that information - Lots of hacks and code that needs to be reversed later on before submission
This commit is contained in:
parent
00546c1a63
commit
f867d5ed72
6 changed files with 906 additions and 531 deletions
|
@ -485,6 +485,47 @@ Khoj:
|
|||
""".strip()
|
||||
)
|
||||
|
||||
# Planning prompt for the iterative research agent: given chat history and the
# record of previous iterations, the model picks one data source and a query
# for the next information-collection step, as a JSON object.
# Fixes two typos in the LLM-facing instructions ("to with" -> "to work with",
# "Which ... you would use" -> "Which ... would you use").
plan_function_execution = PromptTemplate.from_template(
    """
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time
- You are given multiple iterations to work with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution

If you already know the answer to the question, return an empty response, e.g., {{}}.

Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:

{tools}

Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.

Previous Iterations:
{previous_iterations}

Response format:
{{"data_source": "<tool_name>", "query": "<your_new_query>"}}

Chat History:
{chat_history}

Q: {query}
Khoj:
""".strip()
)
|
||||
|
||||
# Renders one prior research iteration (data source used, query issued, and
# any retrieved document/online context) so it can be embedded in the
# planning prompt's "Previous Iterations" section.
previous_iteration = PromptTemplate.from_template(
    """
data_source: {data_source}
query: {query}
context: {context}
onlineContext: {onlineContext}
---
""".strip()
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant.
|
||||
|
|
|
@ -355,9 +355,10 @@ async def extract_references_and_questions(
|
|||
agent_has_entries = await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent)
|
||||
|
||||
if (
|
||||
not ConversationCommand.Notes in conversation_commands
|
||||
and not ConversationCommand.Default in conversation_commands
|
||||
and not agent_has_entries
|
||||
# not ConversationCommand.Notes in conversation_commands
|
||||
# and not ConversationCommand.Default in conversation_commands
|
||||
# and not agent_has_entries
|
||||
True
|
||||
):
|
||||
yield compiled_references, inferred_queries, q
|
||||
return
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -14,6 +14,7 @@ from typing import (
|
|||
Annotated,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
|
@ -39,6 +40,7 @@ from khoj.database.adapters import (
|
|||
AutomationAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
get_user_name,
|
||||
|
@ -614,6 +616,58 @@ async def extract_relevant_summary(
|
|||
return response.strip()
|
||||
|
||||
|
||||
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    meta_log: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    send_response_func: Optional[Callable] = None,
):
    """
    Summarize a single file and stream the summary through `send_response_func`.

    The file is the first of the user's `file_filters` if any are selected,
    otherwise the agent's first indexed file. Status updates stream through
    `send_status_func`. Yields whatever those callables yield; on failure a
    canned error message is sent instead and the error is logged.
    """
    try:
        # Bug fix: start from an empty list instead of None so the len() check
        # below cannot raise TypeError when neither lookup branch matches.
        file_object = []
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent)

        # A user-selected file filter takes precedence over the agent's entries
        if len(file_filters) > 0:
            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

        if len(file_object) == 0:
            response_log = (
                "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
            )
            # Robustness: both callables are Optional; only call them when given
            if send_response_func:
                async for result in send_response_func(response_log):
                    yield result
            return
        contextual_data = " ".join([file.raw_text for file in file_object])
        if not q:
            q = "Create a general summary of the file"
        if send_status_func:
            async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
                yield result

        response = await extract_relevant_summary(
            q,
            contextual_data,
            conversation_history=meta_log,
            subscribed=subscribed,
            uploaded_image_url=uploaded_image_url,
            agent=agent,
        )
        response_log = str(response)
        if send_response_func:
            async for result in send_response_func(response_log):
                yield result
    except Exception as e:
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        if send_response_func:
            async for result in send_response_func(response_log):
                yield result
|
||||
|
||||
|
||||
async def generate_better_image_prompt(
|
||||
q: str,
|
||||
conversation_history: str,
|
||||
|
@ -893,6 +947,7 @@ def generate_chat_response(
|
|||
q: str,
|
||||
meta_log: dict,
|
||||
conversation: Conversation,
|
||||
meta_research: str = "",
|
||||
compiled_references: List[Dict] = [],
|
||||
online_results: Dict[str, Dict] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
|
@ -910,6 +965,9 @@ def generate_chat_response(
|
|||
|
||||
metadata = {}
|
||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
query_to_run = q
|
||||
if meta_research:
|
||||
query_to_run = f"AI Research: {meta_research} {q}"
|
||||
try:
|
||||
partial_completion = partial(
|
||||
save_to_conversation_log,
|
||||
|
@ -937,7 +995,7 @@ def generate_chat_response(
|
|||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
user_query=q,
|
||||
user_query=query_to_run,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
|
@ -956,7 +1014,7 @@ def generate_chat_response(
|
|||
chat_model = conversation_config.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
query_to_run,
|
||||
image_url=uploaded_image_url,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
|
@ -977,7 +1035,7 @@ def generate_chat_response(
|
|||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
q,
|
||||
query_to_run,
|
||||
online_results,
|
||||
meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
|
@ -994,7 +1052,7 @@ def generate_chat_response(
|
|||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
q,
|
||||
query_to_run,
|
||||
online_results,
|
||||
meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
|
|
261
src/khoj/routers/research.py
Normal file
261
src/khoj/routers/research.py
Normal file
|
@ -0,0 +1,261 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import remove_json_codeblock
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
construct_chat_history,
|
||||
generate_summary_from_files,
|
||||
send_message_to_model_wrapper,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
function_calling_description_for_llm,
|
||||
timer,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InformationCollectionIteration:
    """Record of one research iteration: the data source used, the query issued
    against it, and whatever context or online results were gathered."""

    def __init__(
        self, data_source: str, query: str, context: str = None, onlineContext: str = None, result: Any = None
    ):
        self.data_source = data_source
        self.query = query
        self.context = context
        self.onlineContext = onlineContext
        # Bug fix: `result` was accepted but silently discarded; store it so a
        # caller that passes `result=` can read it back.
        self.result = result
|
||||
|
||||
async def apick_next_tool(
    query: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    previous_iterations: List[InformationCollectionIteration] = None,
):
    """
    Given a query, determine which of the available tools the agent should use
    to answer appropriately. One tool at a time; subsequent iterations refine
    the answer.

    Returns an InformationCollectionIteration carrying the chosen data source
    and query, or one with both set to None when the model response is
    unusable.
    """
    tool_options = dict()
    tool_options_str = ""

    # Restrict the offered tools to the agent's configured input tools, if any
    agent_tools = agent.input_tools if agent else []

    for tool, description in function_calling_description_for_llm.items():
        tool_options[tool.value] = description
        if len(agent_tools) == 0 or tool.value in agent_tools:
            tool_options_str += f'- "{tool.value}": "{description}"\n'

    chat_history = construct_chat_history(conversation_history)

    # Bug fix: the default previous_iterations=None previously crashed this
    # loop with a TypeError; treat a missing iteration list as empty.
    previous_iterations_history = ""
    for iteration in previous_iterations or []:
        iteration_data = prompts.previous_iteration.format(
            query=iteration.query,
            data_source=iteration.data_source,
            context=str(iteration.context),
            onlineContext=str(iteration.onlineContext),
        )

        previous_iterations_history += iteration_data

    if uploaded_image_url:
        query = f"[placeholder for user attached image]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    function_planning_prompt = prompts.plan_function_execution.format(
        query=query,
        tools=tool_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
        previous_iterations=previous_iterations_history,
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(
            function_planning_prompt,
            response_type="json_object",
            subscribed=subscribed,
        )

    try:
        response = response.strip()
        response = remove_json_codeblock(response)
        response = json.loads(response)
        suggested_data_source = response.get("data_source", None)
        suggested_query = response.get("query", None)

        return InformationCollectionIteration(
            data_source=suggested_data_source,
            query=suggested_query,
        )

    except Exception as e:
        # Model returned malformed JSON or an unexpected shape; log and signal
        # "no tool chosen" so the caller can terminate the loop gracefully.
        logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
        return InformationCollectionIteration(
            data_source=None,
            query=None,
        )
|
||||
|
||||
|
||||
async def execute_information_collection(
    request: Request,
    user: KhojUser,
    query: str,
    conversation_id: str,
    conversation_history: dict,
    subscribed: bool,
    uploaded_image_url: str = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    location: LocationData = None,
    file_filters: List[str] = None,
):
    """
    Iteratively pick a data source (notes, online search, webpages,
    summarization) and collect information with it, until the planner stops
    choosing a tool or MAX_ITERATIONS is reached. Yields status messages as
    they arrive, then yields each completed InformationCollectionIteration.
    """
    # Avoid the shared-mutable-default pitfall; treat a missing list as empty.
    file_filters = file_filters if file_filters is not None else []

    iteration = 0
    MAX_ITERATIONS = 2
    previous_iterations: List[InformationCollectionIteration] = []
    while iteration < MAX_ITERATIONS:
        online_results: Dict = dict()
        compiled_references, inferred_queries, defiltered_query = [], [], None
        this_iteration = await apick_next_tool(
            query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
        )
        if this_iteration.data_source == ConversationCommand.Notes:
            ## Extract Document References
            compiled_references, inferred_queries, defiltered_query = [], [], None
            async for result in extract_references_and_questions(
                request,
                conversation_history,
                this_iteration.query,
                7,
                None,
                conversation_id,
                [ConversationCommand.Default],
                location,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    compiled_references.extend(result[0])
                    inferred_queries.extend(result[1])
                    defiltered_query = result[2]
            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    context=str(compiled_references),
                )
            )

        elif this_iteration.data_source == ConversationCommand.Online:
            async for result in search_online(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                [],
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result
            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Webpage:
            async for result in read_webpages(
                this_iteration.query,
                conversation_history,
                location,
                user,
                subscribed,
                send_status_func,
                uploaded_image_url=uploaded_image_url,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    direct_web_pages = result

                    webpages = []
                    # Bug fix: the loop variable previously shadowed the `query`
                    # parameter, clobbering the user query for later iterations.
                    for page_query in direct_web_pages:
                        if online_results.get(page_query):
                            online_results[page_query]["webpages"] = direct_web_pages[page_query]["webpages"]
                        else:
                            online_results[page_query] = {"webpages": direct_web_pages[page_query]["webpages"]}

                        for webpage in direct_web_pages[page_query]["webpages"]:
                            webpages.append(webpage["link"])
                    # Bug fix: send_status_func is an async generator function;
                    # yielding its return value leaked the generator object to
                    # the caller instead of the status messages.
                    if send_status_func:
                        async for status in send_status_func(f"**Read web pages**: {webpages}"):
                            yield status

            previous_iterations.append(
                InformationCollectionIteration(
                    data_source=this_iteration.data_source,
                    query=this_iteration.query,
                    onlineContext=online_results,
                )
            )

        elif this_iteration.data_source == ConversationCommand.Summarize:
            response_log = ""
            agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
            if len(file_filters) == 0 and not agent_has_entries:
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context="No files selected for summarization.",
                    )
                )
            elif len(file_filters) > 1 and not agent_has_entries:
                response_log = "Only one file can be selected for summarization."
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
            else:
                # Bug fix: generate_summary_from_files is an async generator;
                # `await`-ing it raised a TypeError. Consume it and forward
                # status messages like the other branches do.
                # NOTE(review): no send_response_func is supplied here — confirm
                # generate_summary_from_files tolerates that before shipping.
                async for response in generate_summary_from_files(
                    q=query,
                    user=user,
                    file_filters=file_filters,
                    meta_log=conversation_history,
                    subscribed=subscribed,
                    send_status_func=send_status_func,
                ):
                    if isinstance(response, dict) and ChatEvent.STATUS in response:
                        yield response[ChatEvent.STATUS]
                    else:
                        response_log = str(response)
        else:
            # Planner chose no (or an unknown) data source: stop researching.
            iteration = MAX_ITERATIONS

        iteration += 1
    for completed_iter in previous_iterations:
        yield completed_iter
|
|
@ -345,6 +345,13 @@ tool_descriptions_for_llm = {
|
|||
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
|
||||
}
|
||||
|
||||
# Tool descriptions shown to the LLM when it plans which data source to use
# next in the iterative research loop. Fixes LLM-facing typos:
# "the there's" -> "there's" and "you are share of" -> "you are sure of".
function_calling_description_for_llm = {
    ConversationCommand.Notes: "Use this if you think the user's personal knowledge base contains relevant context.",
    ConversationCommand.Online: "Use this if you think there's important information on the internet related to the query.",
    ConversationCommand.Webpage: "Use this if the user has provided a webpage URL or you are sure of a webpage URL that will help you directly answer this query",
    ConversationCommand.Summarize: "Use this if you want to retrieve an answer that depends on reading an entire corpus.",
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
|
||||
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
|
||||
|
|
Loading…
Add table
Reference in a new issue