Refactor Run Code tool into separate module and modularize code functions

Move construct_chat_history and the ChatEvent enum into conversation.utils,
and move send_message_to_model_wrapper into conversation.helpers, to
modularize the code and start thinning out the bloated routers.helpers module.

- conversation.utils components are shared functions that the conversation
  child packages can use.
- conversation.helpers components cannot be imported by the conversation
  child packages, but they can import from those child packages.

This division allows better modularity while avoiding circular import
dependencies, as sketched below.
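
A minimal sketch of the resulting import direction, using module paths and
imports that appear in the diff below (the grouping comments are illustrative):

    # khoj/processor/conversation/helpers.py: may reach "down" into the provider
    # child packages and the shared utils, but is never imported by them
    from khoj.processor.conversation.openai.gpt import send_message_to_model
    from khoj.processor.conversation.utils import generate_chatml_messages_with_context

    # Callers such as routers/helpers.py and the run code tool import the wrapper
    # from conversation.helpers and the shared primitives from conversation.utils
    from khoj.processor.conversation.helpers import send_message_to_model_wrapper
    from khoj.processor.conversation.utils import ChatEvent, construct_chat_history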
Debanjum Singh Solanky 2024-10-09 15:54:54 -07:00
parent 8044733201
commit a98f97ed5e
5 changed files with 274 additions and 234 deletions

View file

@@ -0,0 +1,126 @@
from fastapi import HTTPException

from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
from khoj.database.models import ChatModelOptions, KhojUser
from khoj.processor.conversation.anthropic.anthropic_chat import (
    anthropic_send_message_to_model,
)
from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model
from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import send_message_to_model
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel


async def send_message_to_model_wrapper(
    message: str,
    system_message: str = "",
    response_type: str = "text",
    chat_model_option: ChatModelOptions = None,
    subscribed: bool = False,
    uploaded_image_url: str = None,
):
    conversation_config: ChatModelOptions = (
        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
    )
    vision_available = conversation_config.vision_enabled
    if not vision_available and uploaded_image_url:
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            conversation_config = vision_enabled_config
            vision_available = True

    chat_model = conversation_config.chat_model
    max_tokens = (
        conversation_config.subscribed_max_prompt_size
        if subscribed and conversation_config.subscribed_max_prompt_size
        else conversation_config.max_prompt_size
    )
    tokenizer = conversation_config.tokenizer
    model_type = conversation_config.model_type
    vision_available = conversation_config.vision_enabled

    if model_type == ChatModelOptions.ModelType.OFFLINE:
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

        loaded_model = state.offline_chat_processor_config.loaded_model
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            loaded_model=loaded_model,
            tokenizer_name=tokenizer,
            max_prompt_size=max_tokens,
            vision_enabled=vision_available,
            model_type=conversation_config.model_type,
        )

        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model=chat_model,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
        )
    elif model_type == ChatModelOptions.ModelType.OPENAI:
        openai_chat_config = conversation_config.openai_config
        api_key = openai_chat_config.api_key
        api_base_url = openai_chat_config.api_base_url
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            model_type=conversation_config.model_type,
        )

        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
            response_type=response_type,
            api_base_url=api_base_url,
        )
    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            model_type=conversation_config.model_type,
        )

        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
        )
    elif model_type == ChatModelOptions.ModelType.GOOGLE:
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
        )

        return gemini_send_message_to_model(
            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")

View file

@@ -2,6 +2,7 @@ import logging
import math
import queue
from datetime import datetime
from enum import Enum
from time import perf_counter
from typing import Any, Dict, List, Optional
@@ -10,7 +11,7 @@ from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state
@@ -75,6 +76,26 @@ class ThreadedGenerator:
        self.queue.put(StopIteration)


def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    chat_history = ""
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: {chat['message']}\n"
        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
    return chat_history


class ChatEvent(Enum):
    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
    STATUS = "status"


def message_to_log(
    user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
):

View file

@@ -0,0 +1,122 @@
import asyncio
import datetime
import json
import logging
from typing import Any, Callable, List, Optional

import aiohttp

from khoj.database.adapters import ais_user_subscribed
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
from khoj.processor.conversation.utils import (
    ChatEvent,
    construct_chat_history,
    remove_json_codeblock,
)
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


async def run_code(
    query: str,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: str = None,
    agent: Agent = None,
    sandbox_url: str = "http://localhost:8080",
):
    # Generate Code
    if send_status_func:
        async for event in send_status_func(f"**Generate code snippets** for {query}"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Generate programs to execute", logger):
            codes = await generate_python_code(
                query, conversation_history, location_data, user, uploaded_image_url, agent
            )
    except Exception as e:
        raise ValueError(f"Failed to generate code for {query} with error: {e}")

    # Run Code
    if send_status_func:
        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
            yield {ChatEvent.STATUS: event}
    try:
        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
        with timer("Chat actor: Execute generated programs", logger):
            results = await asyncio.gather(*tasks)

        for result in results:
            code = result.pop("code")
            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
            yield {query: {"code": code, "results": result}}
    except Exception as e:
        raise ValueError(f"Failed to run code for {query} with error: {e}")


async def generate_python_code(
    q: str,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    uploaded_image_url: str = None,
    agent: Agent = None,
) -> List[str]:
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    subscribed = await ais_user_subscribed(user)
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    code_generation_prompt = prompts.python_code_generation_prompt.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    response = await send_message_to_model_wrapper(
        code_generation_prompt,
        uploaded_image_url=uploaded_image_url,
        response_type="json_object",
        subscribed=subscribed,
    )

    # Validate that the response is a non-empty, JSON-serializable list
    response = response.strip()
    response = remove_json_codeblock(response)
    response = json.loads(response)
    codes = [code.strip() for code in response["codes"] if code.strip()]

    if not isinstance(codes, list) or not codes or len(codes) == 0:
        raise ValueError
    return codes


async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
    """
    Takes code to run as a string and calls the terrarium API to execute it.
    Returns the result of the code execution as a dictionary.
    """
    headers = {"Content-Type": "application/json"}
    data = {"code": code}

    async with aiohttp.ClientSession() as session:
        async with session.post(sandbox_url, json=data, headers=headers) as response:
            if response.status == 200:
                result: dict[str, Any] = await response.json()
                result["code"] = code
                return result
            else:
                return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}

View file

@@ -19,7 +19,6 @@ from khoj.database.adapters import (
    AgentAdapters,
    ConversationAdapters,
    EntryAdapters,
    FileObjectAdapters,
    PublicConversationAdapters,
    aget_user_name,
)
@@ -29,6 +28,7 @@ from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
    ApiUserRateLimiter,
@@ -46,7 +46,6 @@ from khoj.routers.helpers import (
    is_query_empty,
    is_ready_to_chat,
    read_chat_stream,
    run_code,
    update_telemetry_state,
    validate_conversation_config,
)

View file

@@ -24,7 +24,6 @@ from typing import (
)
from urllib.parse import parse_qs, quote, urljoin, urlparse

import aiohttp
import cron_descriptor
import pytz
import requests
@@ -79,13 +78,16 @@ from khoj.processor.conversation.google.gemini_chat import (
    converse_gemini,
    gemini_send_message_to_model,
)
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
from khoj.processor.conversation.offline.chat_model import (
    converse_offline,
    send_message_to_model_offline,
)
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import (
    ChatEvent,
    ThreadedGenerator,
    construct_chat_history,
    generate_chatml_messages_with_context,
    remove_json_codeblock,
    save_to_conversation_log,
@@ -208,18 +210,6 @@ def get_next_url(request: Request) -> str:
    return urljoin(str(request.base_url).rstrip("/"), next_path)


def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    chat_history = ""
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: {chat['message']}\n"
        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
    return chat_history


def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
    if query.startswith("/notes"):
        return ConversationCommand.Notes
@@ -520,103 +510,6 @@ async def generate_online_subqueries(
    return [q]


async def run_code(
    query: str,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: str = None,
    agent: Agent = None,
    sandbox_url: str = "http://localhost:8080",
):
    # Generate Code
    if send_status_func:
        async for event in send_status_func(f"**Generate code snippets** for {query}"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Generate programs to execute", logger):
            codes = await generate_python_code(
                query, conversation_history, location_data, user, uploaded_image_url, agent
            )
    except Exception as e:
        raise ValueError(f"Failed to generate code for {query} with error: {e}")

    # Run Code
    if send_status_func:
        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
            yield {ChatEvent.STATUS: event}
    try:
        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
        with timer("Chat actor: Execute generated programs", logger):
            results = await asyncio.gather(*tasks)

        for result in results:
            code = result.pop("code")
            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
            yield {query: {"code": code, "results": result}}
    except Exception as e:
        raise ValueError(f"Failed to run code for {query} with error: {e}")


async def generate_python_code(
    q: str,
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
    uploaded_image_url: str = None,
    agent: Agent = None,
) -> List[str]:
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.utcnow().strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    code_generation_prompt = prompts.python_code_generation_prompt.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    response = await send_message_to_model_wrapper(
        code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
    )

    # Validate that the response is a non-empty, JSON-serializable list
    response = response.strip()
    response = remove_json_codeblock(response)
    response = json.loads(response)
    codes = [code.strip() for code in response["codes"] if code.strip()]

    if not isinstance(codes, list) or not codes or len(codes) == 0:
        raise ValueError
    return codes


async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
    """
    Takes code to run as a string and calls the terrarium API to execute it.
    Returns the result of the code execution as a dictionary.
    """
    headers = {"Content-Type": "application/json"}
    data = {"code": code}

    async with aiohttp.ClientSession() as session:
        async with session.post(sandbox_url, json=data, headers=headers) as response:
            if response.status == 200:
                result: dict[str, Any] = await response.json()
                result["code"] = code
                return result
            else:
                return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}


async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -837,119 +730,6 @@ async def generate_better_image_prompt(
    return response


async def send_message_to_model_wrapper(
    message: str,
    system_message: str = "",
    response_type: str = "text",
    chat_model_option: ChatModelOptions = None,
    subscribed: bool = False,
    uploaded_image_url: str = None,
):
    conversation_config: ChatModelOptions = (
        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
    )
    vision_available = conversation_config.vision_enabled
    if not vision_available and uploaded_image_url:
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            conversation_config = vision_enabled_config
            vision_available = True

    chat_model = conversation_config.chat_model
    max_tokens = (
        conversation_config.subscribed_max_prompt_size
        if subscribed and conversation_config.subscribed_max_prompt_size
        else conversation_config.max_prompt_size
    )
    tokenizer = conversation_config.tokenizer
    model_type = conversation_config.model_type
    vision_available = conversation_config.vision_enabled

    if model_type == ChatModelOptions.ModelType.OFFLINE:
        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

        loaded_model = state.offline_chat_processor_config.loaded_model
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            loaded_model=loaded_model,
            tokenizer_name=tokenizer,
            max_prompt_size=max_tokens,
            vision_enabled=vision_available,
            model_type=conversation_config.model_type,
        )

        return send_message_to_model_offline(
            messages=truncated_messages,
            loaded_model=loaded_model,
            model=chat_model,
            max_prompt_size=max_tokens,
            streaming=False,
            response_type=response_type,
        )
    elif model_type == ChatModelOptions.ModelType.OPENAI:
        openai_chat_config = conversation_config.openai_config
        api_key = openai_chat_config.api_key
        api_base_url = openai_chat_config.api_base_url
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            model_type=conversation_config.model_type,
        )

        return send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
            response_type=response_type,
            api_base_url=api_base_url,
        )
    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            model_type=conversation_config.model_type,
        )

        return anthropic_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model,
        )
    elif model_type == ChatModelOptions.ModelType.GOOGLE:
        api_key = conversation_config.openai_config.api_key
        truncated_messages = generate_chatml_messages_with_context(
            user_message=message,
            system_message=system_message,
            model_name=chat_model,
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
        )

        return gemini_send_message_to_model(
            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")


def send_message_to_model_wrapper_sync(
    message: str,
    system_message: str = "",
@@ -1540,14 +1320,6 @@ Manage your automations [here](/automations).
""".strip()


class ChatEvent(Enum):
    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
    STATUS = "status"


class MessageProcessor:
    def __init__(self):
        self.references = {}