mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-28 01:45:07 +01:00
Pass previous iteration results to code interpreter chat actors
This improves the code interpreter chat actors abilitiy to generate code with data collected during the previous iterations
This commit is contained in:
parent
9e7025b330
commit
61df1d5db8
5 changed files with 53 additions and 35 deletions
|
@ -753,7 +753,10 @@ For example:
|
||||||
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}
|
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}
|
||||||
|
|
||||||
Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
|
Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
|
||||||
History:
|
Data from Previous Iterations:
|
||||||
|
{previous_iterations_history}
|
||||||
|
|
||||||
|
Chat History:
|
||||||
{chat_history}
|
{chat_history}
|
||||||
|
|
||||||
User: {query}
|
User: {query}
|
||||||
|
|
|
@ -79,6 +79,40 @@ class ThreadedGenerator:
|
||||||
self.queue.put(StopIteration)
|
self.queue.put(StopIteration)
|
||||||
|
|
||||||
|
|
||||||
|
class InformationCollectionIteration:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_source: str,
|
||||||
|
query: str,
|
||||||
|
context: Dict[str, Dict] = None,
|
||||||
|
onlineContext: dict = None,
|
||||||
|
codeContext: dict = None,
|
||||||
|
summarizedResult: str = None,
|
||||||
|
):
|
||||||
|
self.data_source = data_source
|
||||||
|
self.query = query
|
||||||
|
self.context = context
|
||||||
|
self.onlineContext = onlineContext
|
||||||
|
self.codeContext = codeContext
|
||||||
|
self.summarizedResult = summarizedResult
|
||||||
|
|
||||||
|
|
||||||
|
def construct_iteration_history(
|
||||||
|
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
|
||||||
|
) -> str:
|
||||||
|
previous_iterations_history = ""
|
||||||
|
for idx, iteration in enumerate(previous_iterations):
|
||||||
|
iteration_data = previous_iteration_prompt.format(
|
||||||
|
query=iteration.query,
|
||||||
|
data_source=iteration.data_source,
|
||||||
|
summary=iteration.summarizedResult,
|
||||||
|
index=idx + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_iterations_history += iteration_data
|
||||||
|
return previous_iterations_history
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||||
chat_history = ""
|
chat_history = ""
|
||||||
for chat in conversation_history.get("chat", [])[-n:]:
|
for chat in conversation_history.get("chat", [])[-n:]:
|
||||||
|
|
|
@ -28,6 +28,7 @@ SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
||||||
async def run_code(
|
async def run_code(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
|
previous_iterations_history: str,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
@ -42,7 +43,7 @@ async def run_code(
|
||||||
try:
|
try:
|
||||||
with timer("Chat actor: Generate programs to execute", logger):
|
with timer("Chat actor: Generate programs to execute", logger):
|
||||||
codes = await generate_python_code(
|
codes = await generate_python_code(
|
||||||
query, conversation_history, location_data, user, uploaded_image_url, agent
|
query, conversation_history, previous_iterations_history, location_data, user, uploaded_image_url, agent
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||||
|
@ -66,6 +67,7 @@ async def run_code(
|
||||||
async def generate_python_code(
|
async def generate_python_code(
|
||||||
q: str,
|
q: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
|
previous_iterations_history: str,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
|
@ -85,6 +87,7 @@ async def generate_python_code(
|
||||||
current_date=utc_date,
|
current_date=utc_date,
|
||||||
query=q,
|
query=q,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
|
previous_iterations_history=previous_iterations_history,
|
||||||
location=location,
|
location=location,
|
||||||
username=username,
|
username=username,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
|
|
@ -954,9 +954,11 @@ async def chat(
|
||||||
## Gather Code Results
|
## Gather Code Results
|
||||||
if ConversationCommand.Code in conversation_commands:
|
if ConversationCommand.Code in conversation_commands:
|
||||||
try:
|
try:
|
||||||
|
previous_iteration_history = ""
|
||||||
async for result in run_code(
|
async for result in run_code(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
|
previous_iteration_history,
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
|
|
@ -8,7 +8,11 @@ from fastapi import Request
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import remove_json_codeblock
|
from khoj.processor.conversation.utils import (
|
||||||
|
InformationCollectionIteration,
|
||||||
|
construct_iteration_history,
|
||||||
|
remove_json_codeblock,
|
||||||
|
)
|
||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.processor.tools.run_code import run_code
|
from khoj.processor.tools.run_code import run_code
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
|
@ -30,24 +34,6 @@ from khoj.utils.rawconfig import LocationData
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InformationCollectionIteration:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data_source: str,
|
|
||||||
query: str,
|
|
||||||
context: Dict[str, Dict] = None,
|
|
||||||
onlineContext: dict = None,
|
|
||||||
codeContext: dict = None,
|
|
||||||
summarizedResult: str = None,
|
|
||||||
):
|
|
||||||
self.data_source = data_source
|
|
||||||
self.query = query
|
|
||||||
self.context = context
|
|
||||||
self.onlineContext = onlineContext
|
|
||||||
self.codeContext = codeContext
|
|
||||||
self.summarizedResult = summarizedResult
|
|
||||||
|
|
||||||
|
|
||||||
async def apick_next_tool(
|
async def apick_next_tool(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
|
@ -56,7 +42,7 @@ async def apick_next_tool(
|
||||||
location: LocationData = None,
|
location: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
previous_iterations: List[InformationCollectionIteration] = None,
|
previous_iterations_history: str = None,
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -75,17 +61,6 @@ async def apick_next_tool(
|
||||||
|
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
previous_iterations_history = ""
|
|
||||||
for idx, iteration in enumerate(previous_iterations):
|
|
||||||
iteration_data = prompts.previous_iteration.format(
|
|
||||||
query=iteration.query,
|
|
||||||
data_source=iteration.data_source,
|
|
||||||
summary=iteration.summarizedResult,
|
|
||||||
index=idx + 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
previous_iterations_history += iteration_data
|
|
||||||
|
|
||||||
if uploaded_image_url:
|
if uploaded_image_url:
|
||||||
query = f"[placeholder for user attached image]\n{query}"
|
query = f"[placeholder for user attached image]\n{query}"
|
||||||
|
|
||||||
|
@ -98,7 +73,6 @@ async def apick_next_tool(
|
||||||
location_data = f"{location}" if location else "Unknown"
|
location_data = f"{location}" if location else "Unknown"
|
||||||
username = prompts.user_name.format(name=user_name) if user_name else ""
|
username = prompts.user_name.format(name=user_name) if user_name else ""
|
||||||
|
|
||||||
# TODO Add current date/time to the query
|
|
||||||
function_planning_prompt = prompts.plan_function_execution.format(
|
function_planning_prompt = prompts.plan_function_execution.format(
|
||||||
query=query,
|
query=query,
|
||||||
tools=tool_options_str,
|
tools=tool_options_str,
|
||||||
|
@ -166,6 +140,7 @@ async def execute_information_collection(
|
||||||
code_results: Dict = dict()
|
code_results: Dict = dict()
|
||||||
compiled_references: List[Any] = []
|
compiled_references: List[Any] = []
|
||||||
inferred_queries: List[Any] = []
|
inferred_queries: List[Any] = []
|
||||||
|
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||||
|
|
||||||
result: str = ""
|
result: str = ""
|
||||||
|
|
||||||
|
@ -177,7 +152,7 @@ async def execute_information_collection(
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
agent,
|
agent,
|
||||||
previous_iterations,
|
previous_iterations_history,
|
||||||
MAX_ITERATIONS,
|
MAX_ITERATIONS,
|
||||||
)
|
)
|
||||||
if this_iteration.data_source == ConversationCommand.Notes:
|
if this_iteration.data_source == ConversationCommand.Notes:
|
||||||
|
@ -268,6 +243,7 @@ async def execute_information_collection(
|
||||||
async for result in run_code(
|
async for result in run_code(
|
||||||
this_iteration.query,
|
this_iteration.query,
|
||||||
conversation_history,
|
conversation_history,
|
||||||
|
previous_iterations_history,
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
send_status_func,
|
send_status_func,
|
||||||
|
|
Loading…
Reference in a new issue