Pass tool AIs the iteration history as chat history for better context

Separate the conversation history with the user from the conversation
history between the tool AIs and the researcher AI.

Tool AIs don't need the top-level conversation history; that context is
meant for the researcher AI.

The invoked tool AIs need the previous attempts at using the tool from this
research run's iteration history to better tune their next run.

Or at least that is the hypothesis for breaking the model's looping.
Debanjum 2024-10-28 19:21:20 -07:00
parent d865994062
commit dc8e89b5de
2 changed files with 49 additions and 9 deletions
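
Before the diffs, a minimal self-contained sketch of the intended behaviour: each prior iteration of the current research run becomes a user/assistant turn pair that the invoked tool AI can consume as chat history. The Iteration dataclass, the tool_chat_history name, and the sample values below are illustrative stand-ins rather than the committed code; the real helper is construct_tool_chat_history in the first changed file.

# Illustrative sketch only: `Iteration` stands in for khoj's
# InformationCollectionIteration; the sample query/result values are made up.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Iteration:
    query: str
    summarizedResult: str
    onlineContext: Optional[Dict[str, dict]] = None  # keys are the searches that were run


def tool_chat_history(previous_iterations: List[Iteration]) -> Dict[str, list]:
    """Turn this run's prior iterations into chat turns for the tool AI."""
    chat: list = []
    for it in previous_iterations:
        # For the online tool, the inferred queries are the searches it ran.
        inferred = list(it.onlineContext.keys()) if it.onlineContext else []
        chat += [
            {"by": "you", "message": it.query},
            {
                "by": "khoj",
                "intent": {"type": "remember", "inferred-queries": inferred, "query": it.query},
                "message": it.summarizedResult,
            },
        ]
    return {"chat": chat}


if __name__ == "__main__":
    runs = [
        Iteration(
            query="Find recent posts on the Khoj research mode",
            summarizedResult="Found two blog posts describing the research mode.",
            onlineContext={"khoj research mode blog": {}},
        )
    ]
    print(tool_chat_history(runs))
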


@@ -9,7 +9,7 @@ from datetime import datetime
 from enum import Enum
 from io import BytesIO
 from time import perf_counter
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import PIL.Image
 import requests
@@ -29,7 +29,12 @@ from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
+from khoj.utils.helpers import (
+    ConversationCommand,
+    in_debug_mode,
+    is_none_or_empty,
+    merge_dicts,
+)

 logger = logging.getLogger(__name__)

 model_to_prompt_size = {
@@ -139,6 +144,41 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
     return chat_history


+def construct_tool_chat_history(
+    previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
+) -> Dict[str, list]:
+    chat_history: list = []
+    inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
+    if tool == ConversationCommand.Notes:
+        inferred_query_extractor = (
+            lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
+        )
+    elif tool == ConversationCommand.Online:
+        inferred_query_extractor = (
+            lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
+        )
+    elif tool == ConversationCommand.Code:
+        inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
+    for iteration in previous_iterations:
+        chat_history += [
+            {
+                "by": "you",
+                "message": iteration.query,
+            },
+            {
+                "by": "khoj",
+                "intent": {
+                    "type": "remember",
+                    "inferred-queries": inferred_query_extractor(iteration),
+                    "query": iteration.query,
+                },
+                "message": iteration.summarizedResult,
+            },
+        ]
+
+    return {"chat": chat_history}
+
+
 class ChatEvent(Enum):
     START_LLM_RESPONSE = "start_llm_response"
     END_LLM_RESPONSE = "end_llm_response"


@@ -12,6 +12,7 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import (
     InformationCollectionIteration,
     construct_iteration_history,
+    construct_tool_chat_history,
     remove_json_codeblock,
 )
 from khoj.processor.tools.online_search import read_webpages, search_online
@@ -147,7 +148,6 @@ async def execute_information_collection(
         code_results: Dict = dict()
         document_results: List[Dict[str, str]] = []
         summarize_files: str = ""
-        inferred_queries: List[Any] = []

         this_iteration = InformationCollectionIteration(tool=None, query=query)
         previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@@ -174,7 +174,7 @@ async def execute_information_collection(
             document_results = []
             async for result in extract_references_and_questions(
                 request,
-                conversation_history,
+                construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
                 this_iteration.query,
                 7,
                 None,
@@ -208,7 +208,7 @@ async def execute_information_collection(
         elif this_iteration.tool == ConversationCommand.Online:
             async for result in search_online(
                 this_iteration.query,
-                conversation_history,
+                construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
                 location,
                 user,
                 send_status_func,
@@ -228,7 +228,7 @@ async def execute_information_collection(
             try:
                 async for result in read_webpages(
                     this_iteration.query,
-                    conversation_history,
+                    construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
                     location,
                     user,
                     send_status_func,
@@ -258,8 +258,8 @@ async def execute_information_collection(
             try:
                 async for result in run_code(
                     this_iteration.query,
-                    conversation_history,
-                    previous_iterations_history,
+                    construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
+                    "",
                     location,
                     user,
                     send_status_func,
@@ -286,7 +286,7 @@ async def execute_information_collection(
                 this_iteration.query,
                 user,
                 file_filters,
-                conversation_history,
+                construct_tool_chat_history(previous_iterations),
                 query_images=query_images,
                 agent=agent,
                 send_status_func=send_status_func,
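
For completeness, a hedged sketch of how a tool call site can use the new helper, mirroring the second changed file above. previous_iterations is left empty here only so the snippet runs on its own; in the research loop it holds this run's InformationCollectionIteration objects.

# Hypothetical call-site usage; not part of the commit itself.
from khoj.processor.conversation.utils import construct_tool_chat_history
from khoj.utils.helpers import ConversationCommand

previous_iterations = []  # in the real loop: this run's InformationCollectionIteration objects

notes_history = construct_tool_chat_history(previous_iterations, ConversationCommand.Notes)
online_history = construct_tool_chat_history(previous_iterations, ConversationCommand.Online)
# Each call returns {"chat": [...]} in the same shape that was previously passed
# via conversation_history, scoped to the tool's own prior attempts in this run.
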