mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Enable analysing user documents in code sandbox and other improvements
- Run one program at a time, instead of allowing model to pass multiple programs to run in parallel to simplify logic for model - Update prompt to give more example of complex, multi-line code - Allow passing user files as input into code sandbox for analysis - Log code execution timer at info level to evaluate execution latencies in production - Type the generated code for easier processing by caller functions
This commit is contained in:
parent
ba2471dc02
commit
7b39f2014a
2 changed files with 65 additions and 29 deletions
|
@ -870,25 +870,40 @@ Khoj:
|
||||||
# --
|
# --
|
||||||
python_code_generation_prompt = PromptTemplate.from_template(
|
python_code_generation_prompt = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
You are Khoj, an advanced python programmer. You are tasked with constructing **up to three** python programs to best answer the user query.
|
You are Khoj, an advanced python programmer. You are tasked with constructing a python program to best answer the user query.
|
||||||
- The python program will run in a pyodide python sandbox with no network access.
|
- The python program will run in a pyodide python sandbox with no network access.
|
||||||
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query
|
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query.
|
||||||
- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4, sympy, brotli, cryptography, fast-parquet
|
- The sandbox has access to the standard library, matplotlib, panda, numpy, scipy, bs4, sympy, brotli, cryptography, fast-parquet.
|
||||||
|
- List known file paths to required user documents in "input_files" and known links to required documents from the web in the "input_links" field.
|
||||||
|
- The python program should be self-contained. It can only read data generated by the program itself and from provided input_files, input_links by their basename (i.e filename excluding file path).
|
||||||
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
|
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
|
||||||
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
|
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
|
||||||
- Use as much context from the previous questions and answers as required to generate your code.
|
- Use as much context from the previous questions and answers as required to generate your code.
|
||||||
{personality_context}
|
{personality_context}
|
||||||
What code will you need to write, if any, to answer the user's question?
|
What code will you need to write to answer the user's question?
|
||||||
Provide code programs as a list of strings in a JSON object with key "codes".
|
|
||||||
Current Date: {current_date}
|
Current Date: {current_date}
|
||||||
User's Location: {location}
|
User's Location: {location}
|
||||||
{username}
|
{username}
|
||||||
|
|
||||||
The JSON schema is of the form {{"codes": ["code1", "code2", "code3"]}}
|
The response JSON schema is of the form {{"code": "<python_code>", "input_files": ["file_path_1", "file_path_2"], "input_links": ["link_1", "link_2"]}}
|
||||||
For example:
|
Examples:
|
||||||
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}
|
---
|
||||||
|
{{
|
||||||
|
"code": "# Input values\\nprincipal = 43235\\nrate = 5.24\\nyears = 5\\n\\n# Convert rate to decimal\\nrate_decimal = rate / 100\\n\\n# Calculate final amount\\nfinal_amount = principal * (1 + rate_decimal) ** years\\n\\n# Calculate interest earned\\ninterest_earned = final_amount - principal\\n\\n# Print results with formatting\\nprint(f"Interest Earned: ${{interest_earned:,.2f}}")\\nprint(f"Final Amount: ${{final_amount:,.2f}}")"
|
||||||
|
}}
|
||||||
|
|
||||||
Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
|
{{
|
||||||
|
"code": "import re\\n\\n# Read org file\\nfile_path = 'tasks.org'\\nwith open(file_path, 'r') as f:\\n content = f.read()\\n\\n# Get today's date in YYYY-MM-DD format\\ntoday = datetime.now().strftime('%Y-%m-%d')\\npattern = r'\*+\s+.*\\n.*SCHEDULED:\s+<' + today + r'.*>'\\n\\n# Find all matches using multiline mode\\nmatches = re.findall(pattern, content, re.MULTILINE)\\ncount = len(matches)\\n\\n# Display count\\nprint(f'Count of scheduled tasks for today: {{count}}')",
|
||||||
|
"input_files": ["/home/linux/tasks.org"]
|
||||||
|
}}
|
||||||
|
|
||||||
|
{{
|
||||||
|
"code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Load the CSV file\\ndf = pd.read_csv('world_population_by_year.csv')\\n\\n# Plot the data\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Population'], marker='o')\\n\\n# Add titles and labels\\nplt.title('Population by Year')\\nplt.xlabel('Year')\\nplt.ylabel('Population')\\n\\n# Save the plot to a file\\nplt.savefig('population_by_year_plot.png')",
|
||||||
|
"input_links": ["https://population.un.org/world_population_by_year.csv"]
|
||||||
|
}}
|
||||||
|
|
||||||
|
Now it's your turn to construct a python program to answer the user's question. Provide the code, required input files and input links in a JSON object. Do not say anything else.
|
||||||
Context:
|
Context:
|
||||||
---
|
---
|
||||||
{context}
|
{context}
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
import asyncio
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, List, NamedTuple, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from khoj.database.adapters import ais_user_subscribed
|
from khoj.database.adapters import FileObjectAdapters
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, FileObject, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
|
@ -17,7 +17,7 @@ from khoj.processor.conversation.utils import (
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
)
|
)
|
||||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import is_none_or_empty, timer
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -26,6 +26,12 @@ logger = logging.getLogger(__name__)
|
||||||
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratedCode(NamedTuple):
|
||||||
|
code: str
|
||||||
|
input_files: List[str]
|
||||||
|
input_links: List[str]
|
||||||
|
|
||||||
|
|
||||||
async def run_code(
|
async def run_code(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
|
@ -41,11 +47,11 @@ async def run_code(
|
||||||
):
|
):
|
||||||
# Generate Code
|
# Generate Code
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**Generate code snippets** for {query}"):
|
async for event in send_status_func(f"**Generate code snippet** for {query}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
try:
|
try:
|
||||||
with timer("Chat actor: Generate programs to execute", logger):
|
with timer("Chat actor: Generate programs to execute", logger):
|
||||||
codes = await generate_python_code(
|
generated_code = await generate_python_code(
|
||||||
query,
|
query,
|
||||||
conversation_history,
|
conversation_history,
|
||||||
context,
|
context,
|
||||||
|
@ -59,15 +65,26 @@ async def run_code(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||||
|
|
||||||
|
# Prepare Input Data
|
||||||
|
input_data = []
|
||||||
|
user_input_files: List[FileObject] = []
|
||||||
|
for input_file in generated_code.input_files:
|
||||||
|
user_input_files += await FileObjectAdapters.aget_file_objects_by_name(user, input_file)
|
||||||
|
for f in user_input_files:
|
||||||
|
input_data.append(
|
||||||
|
{
|
||||||
|
"filename": os.path.basename(f.file_name),
|
||||||
|
"b64_data": base64.b64encode(f.raw_text.encode("utf-8")).decode("utf-8"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Run Code
|
# Run Code
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
|
async for event in send_status_func(f"**Running code snippet**"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
try:
|
try:
|
||||||
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
|
with timer("Chat actor: Execute generated program", logger, log_level=logging.INFO):
|
||||||
with timer("Chat actor: Execute generated programs", logger):
|
result = await execute_sandboxed_python(generated_code.code, input_data, sandbox_url)
|
||||||
results = await asyncio.gather(*tasks)
|
|
||||||
for result in results:
|
|
||||||
code = result.pop("code")
|
code = result.pop("code")
|
||||||
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
||||||
yield {query: {"code": code, "results": result}}
|
yield {query: {"code": code, "results": result}}
|
||||||
|
@ -81,14 +98,13 @@ async def generate_python_code(
|
||||||
context: str,
|
context: str,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: list[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
) -> List[str]:
|
) -> GeneratedCode:
|
||||||
location = f"{location_data}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||||
subscribed = await ais_user_subscribed(user)
|
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
|
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
@ -118,21 +134,26 @@ async def generate_python_code(
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
response = clean_json(response)
|
response = clean_json(response)
|
||||||
response = json.loads(response)
|
response = json.loads(response)
|
||||||
codes = [code.strip() for code in response["codes"] if code.strip()]
|
code = response.get("code", "").strip()
|
||||||
|
input_files = response.get("input_files", [])
|
||||||
|
input_links = response.get("input_links", [])
|
||||||
|
|
||||||
if not isinstance(codes, list) or not codes or len(codes) == 0:
|
if not isinstance(code, str) or is_none_or_empty(code):
|
||||||
raise ValueError
|
raise ValueError
|
||||||
return codes
|
return GeneratedCode(code, input_files, input_links)
|
||||||
|
|
||||||
|
|
||||||
async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
|
async def execute_sandboxed_python(code: str, input_data: list[dict], sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Takes code to run as a string and calls the terrarium API to execute it.
|
Takes code to run as a string and calls the terrarium API to execute it.
|
||||||
Returns the result of the code execution as a dictionary.
|
Returns the result of the code execution as a dictionary.
|
||||||
|
|
||||||
|
Reference data i/o format based on Terrarium example client code at:
|
||||||
|
https://github.com/cohere-ai/cohere-terrarium/blob/main/example-clients/python/terrarium_client.py
|
||||||
"""
|
"""
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
cleaned_code = clean_code_python(code)
|
cleaned_code = clean_code_python(code)
|
||||||
data = {"code": cleaned_code}
|
data = {"code": cleaned_code, "files": input_data}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||||
|
|
Loading…
Add table
Reference in a new issue