mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Pass tool AIs iteration history as chat history for better context
Separate the conversation history with the user from the conversation history between the tool AIs and the researcher AI. Tool AIs don't need the top-level conversation history; that context is meant for the researcher AI. The invoked tool AIs need their previous attempts at using the tool in this research run's iteration history to better tune their next run. Or at least that is the hypothesis for breaking the models out of looping.
This commit is contained in:
parent
d865994062
commit
dc8e89b5de
2 changed files with 49 additions and 9 deletions
|
@ -9,7 +9,7 @@ from datetime import datetime
|
|||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import PIL.Image
|
||||
import requests
|
||||
|
@ -29,7 +29,12 @@ from khoj.search_filter.date_filter import DateFilter
|
|||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
in_debug_mode,
|
||||
is_none_or_empty,
|
||||
merge_dicts,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_to_prompt_size = {
|
||||
|
@ -139,6 +144,41 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
|||
return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
    previous_iterations: List[InformationCollectionIteration], tool: Optional[ConversationCommand] = None
) -> Dict[str, list]:
    """Build a chat-history dict for a tool AI from prior research iterations.

    Each previous iteration becomes a user ("you") turn carrying the iteration's
    query, followed by a "khoj" turn carrying the summarized result, with the
    queries the tool actually ran recorded as inferred queries. Which attribute
    of the iteration supplies those inferred queries depends on the tool, since
    each tool stores its results differently.

    Args:
        previous_iterations: Iterations already executed in this research run.
        tool: The tool whose stored context supplies the inferred queries. When
            None or an unrecognized tool, no inferred queries are extracted.

    Returns:
        A dict of the form {"chat": [...]} matching the conversation-history
        shape the tool AIs expect.
    """
    # Dispatch table: each tool stores the queries it ran under a different
    # iteration attribute. Unknown tools fall back to an empty extractor.
    extractors: Dict[ConversationCommand, Callable[[InformationCollectionIteration], List[str]]] = {
        ConversationCommand.Notes: lambda it: [c["query"] for c in it.context] if it.context else [],
        ConversationCommand.Online: lambda it: list(it.onlineContext.keys()) if it.onlineContext else [],
        ConversationCommand.Code: lambda it: list(it.codeContext.keys()) if it.codeContext else [],
    }
    extract_inferred_queries = extractors.get(tool, lambda it: [])

    chat_history: list = []
    for iteration in previous_iterations:
        chat_history += [
            {
                "by": "you",
                "message": iteration.query,
            },
            {
                "by": "khoj",
                "intent": {
                    "type": "remember",
                    "inferred-queries": extract_inferred_queries(iteration),
                    "query": iteration.query,
                },
                "message": iteration.summarizedResult,
            },
        ]

    return {"chat": chat_history}
|
||||
|
||||
|
||||
class ChatEvent(Enum):
|
||||
START_LLM_RESPONSE = "start_llm_response"
|
||||
END_LLM_RESPONSE = "end_llm_response"
|
||||
|
|
|
@ -12,6 +12,7 @@ from khoj.processor.conversation import prompts
|
|||
from khoj.processor.conversation.utils import (
|
||||
InformationCollectionIteration,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
|
@ -147,7 +148,6 @@ async def execute_information_collection(
|
|||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
summarize_files: str = ""
|
||||
inferred_queries: List[Any] = []
|
||||
this_iteration = InformationCollectionIteration(tool=None, query=query)
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
|
||||
|
@ -174,7 +174,7 @@ async def execute_information_collection(
|
|||
document_results = []
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
conversation_history,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
|
||||
this_iteration.query,
|
||||
7,
|
||||
None,
|
||||
|
@ -208,7 +208,7 @@ async def execute_information_collection(
|
|||
elif this_iteration.tool == ConversationCommand.Online:
|
||||
async for result in search_online(
|
||||
this_iteration.query,
|
||||
conversation_history,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
|
@ -228,7 +228,7 @@ async def execute_information_collection(
|
|||
try:
|
||||
async for result in read_webpages(
|
||||
this_iteration.query,
|
||||
conversation_history,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
|
@ -258,8 +258,8 @@ async def execute_information_collection(
|
|||
try:
|
||||
async for result in run_code(
|
||||
this_iteration.query,
|
||||
conversation_history,
|
||||
previous_iterations_history,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
|
||||
"",
|
||||
location,
|
||||
user,
|
||||
send_status_func,
|
||||
|
@ -286,7 +286,7 @@ async def execute_information_collection(
|
|||
this_iteration.query,
|
||||
user,
|
||||
file_filters,
|
||||
conversation_history,
|
||||
construct_tool_chat_history(previous_iterations),
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
send_status_func=send_status_func,
|
||||
|
|
Loading…
Reference in a new issue