Pass previous iteration results to code interpreter chat actors

This improves the code interpreter chat actors abilitiy to generate
code with data collected during the previous iterations
This commit is contained in:
Debanjum Singh Solanky 2024-10-10 00:59:25 -07:00
parent 9e7025b330
commit 61df1d5db8
5 changed files with 53 additions and 35 deletions

View file

@ -753,7 +753,10 @@ For example:
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}
Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
History:
Data from Previous Iterations:
{previous_iterations_history}
Chat History:
{chat_history}
User: {query}

View file

@ -79,6 +79,40 @@ class ThreadedGenerator:
self.queue.put(StopIteration)
class InformationCollectionIteration:
def __init__(
self,
data_source: str,
query: str,
context: Dict[str, Dict] = None,
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
):
self.data_source = data_source
self.query = query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
) -> str:
previous_iterations_history = ""
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
query=iteration.query,
data_source=iteration.data_source,
summary=iteration.summarizedResult,
index=idx + 1,
)
previous_iterations_history += iteration_data
return previous_iterations_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:

View file

@ -28,6 +28,7 @@ SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
async def run_code(
query: str,
conversation_history: dict,
previous_iterations_history: str,
location_data: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
@ -42,7 +43,7 @@ async def run_code(
try:
with timer("Chat actor: Generate programs to execute", logger):
codes = await generate_python_code(
query, conversation_history, location_data, user, uploaded_image_url, agent
query, conversation_history, previous_iterations_history, location_data, user, uploaded_image_url, agent
)
except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}")
@ -66,6 +67,7 @@ async def run_code(
async def generate_python_code(
q: str,
conversation_history: dict,
previous_iterations_history: str,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
@ -85,6 +87,7 @@ async def generate_python_code(
current_date=utc_date,
query=q,
chat_history=chat_history,
previous_iterations_history=previous_iterations_history,
location=location,
username=username,
personality_context=personality_context,

View file

@ -954,9 +954,11 @@ async def chat(
## Gather Code Results
if ConversationCommand.Code in conversation_commands:
try:
previous_iteration_history = ""
async for result in run_code(
defiltered_query,
meta_log,
previous_iteration_history,
location,
user,
partial(send_event, ChatEvent.STATUS),

View file

@ -8,7 +8,11 @@ from fastapi import Request
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import remove_json_codeblock
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
construct_iteration_history,
remove_json_codeblock,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
@ -30,24 +34,6 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class InformationCollectionIteration:
def __init__(
self,
data_source: str,
query: str,
context: Dict[str, Dict] = None,
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
):
self.data_source = data_source
self.query = query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
async def apick_next_tool(
query: str,
conversation_history: dict,
@ -56,7 +42,7 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations: List[InformationCollectionIteration] = None,
previous_iterations_history: str = None,
max_iterations: int = 5,
):
"""
@ -75,17 +61,6 @@ async def apick_next_tool(
chat_history = construct_chat_history(conversation_history)
previous_iterations_history = ""
for idx, iteration in enumerate(previous_iterations):
iteration_data = prompts.previous_iteration.format(
query=iteration.query,
data_source=iteration.data_source,
summary=iteration.summarizedResult,
index=idx + 1,
)
previous_iterations_history += iteration_data
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
@ -98,7 +73,6 @@ async def apick_next_tool(
location_data = f"{location}" if location else "Unknown"
username = prompts.user_name.format(name=user_name) if user_name else ""
# TODO Add current date/time to the query
function_planning_prompt = prompts.plan_function_execution.format(
query=query,
tools=tool_options_str,
@ -166,6 +140,7 @@ async def execute_information_collection(
code_results: Dict = dict()
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
result: str = ""
@ -177,7 +152,7 @@ async def execute_information_collection(
location,
user_name,
agent,
previous_iterations,
previous_iterations_history,
MAX_ITERATIONS,
)
if this_iteration.data_source == ConversationCommand.Notes:
@ -268,6 +243,7 @@ async def execute_information_collection(
async for result in run_code(
this_iteration.query,
conversation_history,
previous_iterations_history,
location,
user,
send_status_func,