From 61df1d5db87551e6a327949d283094ab77d8cb79 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 10 Oct 2024 00:59:25 -0700 Subject: [PATCH] Pass previous iteration results to code interpreter chat actors This improves the code interpreter chat actors' ability to generate code with data collected during the previous iterations --- src/khoj/processor/conversation/prompts.py | 5 ++- src/khoj/processor/conversation/utils.py | 34 ++++++++++++++++++ src/khoj/processor/tools/run_code.py | 5 ++- src/khoj/routers/api_chat.py | 2 ++ src/khoj/routers/research.py | 42 +++++----------------- 5 files changed, 53 insertions(+), 35 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 23788cab..a82dc18f 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -753,7 +753,10 @@ For example: {{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}} Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else. 
-History: +Data from Previous Iterations: +{previous_iterations_history} + +Chat History: {chat_history} User: {query} diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 339024ae..b8960e0b 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -79,6 +79,40 @@ class ThreadedGenerator: self.queue.put(StopIteration) +class InformationCollectionIteration: + def __init__( + self, + data_source: str, + query: str, + context: Dict[str, Dict] = None, + onlineContext: dict = None, + codeContext: dict = None, + summarizedResult: str = None, + ): + self.data_source = data_source + self.query = query + self.context = context + self.onlineContext = onlineContext + self.codeContext = codeContext + self.summarizedResult = summarizedResult + + +def construct_iteration_history( + previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str +) -> str: + previous_iterations_history = "" + for idx, iteration in enumerate(previous_iterations): + iteration_data = previous_iteration_prompt.format( + query=iteration.query, + data_source=iteration.data_source, + summary=iteration.summarizedResult, + index=idx + 1, + ) + + previous_iterations_history += iteration_data + return previous_iterations_history + + def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: chat_history = "" for chat in conversation_history.get("chat", [])[-n:]: diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 681c5f94..384b993c 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -28,6 +28,7 @@ SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080") async def run_code( query: str, conversation_history: dict, + previous_iterations_history: str, location_data: LocationData, user: KhojUser, send_status_func: Optional[Callable] = None, @@ -42,7 +43,7 @@ 
async def run_code( try: with timer("Chat actor: Generate programs to execute", logger): codes = await generate_python_code( - query, conversation_history, location_data, user, uploaded_image_url, agent + query, conversation_history, previous_iterations_history, location_data, user, uploaded_image_url, agent ) except Exception as e: raise ValueError(f"Failed to generate code for {query} with error: {e}") @@ -66,6 +67,7 @@ async def run_code( async def generate_python_code( q: str, conversation_history: dict, + previous_iterations_history: str, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None, @@ -85,6 +87,7 @@ async def generate_python_code( current_date=utc_date, query=q, chat_history=chat_history, + previous_iterations_history=previous_iterations_history, location=location, username=username, personality_context=personality_context, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 60d414de..2a9654ef 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -954,9 +954,11 @@ async def chat( ## Gather Code Results if ConversationCommand.Code in conversation_commands: try: + previous_iteration_history = "" async for result in run_code( defiltered_query, meta_log, + previous_iteration_history, location, user, partial(send_event, ChatEvent.STATUS), diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 9a5e3169..1ada9e7a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,11 @@ from fastapi import Request from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import remove_json_codeblock +from khoj.processor.conversation.utils import ( + InformationCollectionIteration, + construct_iteration_history, + remove_json_codeblock, +) from khoj.processor.tools.online_search import 
read_webpages, search_online from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions @@ -30,24 +34,6 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) -class InformationCollectionIteration: - def __init__( - self, - data_source: str, - query: str, - context: Dict[str, Dict] = None, - onlineContext: dict = None, - codeContext: dict = None, - summarizedResult: str = None, - ): - self.data_source = data_source - self.query = query - self.context = context - self.onlineContext = onlineContext - self.codeContext = codeContext - self.summarizedResult = summarizedResult - - async def apick_next_tool( query: str, conversation_history: dict, @@ -56,7 +42,7 @@ async def apick_next_tool( location: LocationData = None, user_name: str = None, agent: Agent = None, - previous_iterations: List[InformationCollectionIteration] = None, + previous_iterations_history: str = None, max_iterations: int = 5, ): """ @@ -75,17 +61,6 @@ async def apick_next_tool( chat_history = construct_chat_history(conversation_history) - previous_iterations_history = "" - for idx, iteration in enumerate(previous_iterations): - iteration_data = prompts.previous_iteration.format( - query=iteration.query, - data_source=iteration.data_source, - summary=iteration.summarizedResult, - index=idx + 1, - ) - - previous_iterations_history += iteration_data - if uploaded_image_url: query = f"[placeholder for user attached image]\n{query}" @@ -98,7 +73,6 @@ async def apick_next_tool( location_data = f"{location}" if location else "Unknown" username = prompts.user_name.format(name=user_name) if user_name else "" - # TODO Add current date/time to the query function_planning_prompt = prompts.plan_function_execution.format( query=query, tools=tool_options_str, @@ -166,6 +140,7 @@ async def execute_information_collection( code_results: Dict = dict() compiled_references: List[Any] = [] inferred_queries: List[Any] = [] + 
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) result: str = "" @@ -177,7 +152,7 @@ async def execute_information_collection( location, user_name, agent, - previous_iterations, + previous_iterations_history, MAX_ITERATIONS, ) if this_iteration.data_source == ConversationCommand.Notes: @@ -268,6 +243,7 @@ async def execute_information_collection( async for result in run_code( this_iteration.query, conversation_history, + previous_iterations_history, location, user, send_status_func,