mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Refactor Run Code tool into separate module and modularize code functions
Move construct_chat_history and ChatEvent enum into conversation.utils and move send_message_to_model_wrapper to conversation.helper to modularize code. And start thinning out the bloated routers.helper - conversation.util components are shared functions that conversation child packages can use. - conversation.helper components can't be imported by conversation packages but it can use these child packages This division allows better modularity while avoiding circular import dependencies
This commit is contained in:
parent
8044733201
commit
a98f97ed5e
5 changed files with 274 additions and 234 deletions
126
src/khoj/processor/conversation/helpers.py
Normal file
126
src/khoj/processor/conversation/helpers.py
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
|
||||||
|
from khoj.database.models import ChatModelOptions, KhojUser
|
||||||
|
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||||
|
anthropic_send_message_to_model,
|
||||||
|
)
|
||||||
|
from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model
|
||||||
|
from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline
|
||||||
|
from khoj.processor.conversation.openai.gpt import send_message_to_model
|
||||||
|
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.config import OfflineChatProcessorModel
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message_to_model_wrapper(
|
||||||
|
message: str,
|
||||||
|
system_message: str = "",
|
||||||
|
response_type: str = "text",
|
||||||
|
chat_model_option: ChatModelOptions = None,
|
||||||
|
subscribed: bool = False,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
):
|
||||||
|
conversation_config: ChatModelOptions = (
|
||||||
|
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
)
|
||||||
|
|
||||||
|
vision_available = conversation_config.vision_enabled
|
||||||
|
if not vision_available and uploaded_image_url:
|
||||||
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
|
if vision_enabled_config:
|
||||||
|
conversation_config = vision_enabled_config
|
||||||
|
vision_available = True
|
||||||
|
|
||||||
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = (
|
||||||
|
conversation_config.subscribed_max_prompt_size
|
||||||
|
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||||
|
else conversation_config.max_prompt_size
|
||||||
|
)
|
||||||
|
tokenizer = conversation_config.tokenizer
|
||||||
|
model_type = conversation_config.model_type
|
||||||
|
vision_available = conversation_config.vision_enabled
|
||||||
|
|
||||||
|
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_message_to_model_offline(
|
||||||
|
messages=truncated_messages,
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
model=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
streaming=False,
|
||||||
|
response_type=response_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
openai_chat_config = conversation_config.openai_config
|
||||||
|
api_key = openai_chat_config.api_key
|
||||||
|
api_base_url = openai_chat_config.api_base_url
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
)
|
||||||
|
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return anthropic_send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
)
|
||||||
|
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
return gemini_send_message_to_model(
|
||||||
|
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
@ -2,6 +2,7 @@ import logging
|
||||||
import math
|
import math
|
||||||
import queue
|
import queue
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -10,7 +11,7 @@ from langchain.schema import ChatMessage
|
||||||
from llama_cpp.llama import Llama
|
from llama_cpp.llama import Llama
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
|
||||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
@ -75,6 +76,26 @@ class ThreadedGenerator:
|
||||||
self.queue.put(StopIteration)
|
self.queue.put(StopIteration)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||||
|
chat_history = ""
|
||||||
|
for chat in conversation_history.get("chat", [])[-n:]:
|
||||||
|
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: {chat['message']}\n"
|
||||||
|
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEvent(Enum):
|
||||||
|
START_LLM_RESPONSE = "start_llm_response"
|
||||||
|
END_LLM_RESPONSE = "end_llm_response"
|
||||||
|
MESSAGE = "message"
|
||||||
|
REFERENCES = "references"
|
||||||
|
STATUS = "status"
|
||||||
|
|
||||||
|
|
||||||
def message_to_log(
|
def message_to_log(
|
||||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||||
):
|
):
|
||||||
|
|
122
src/khoj/processor/tools/run_code.py
Normal file
122
src/khoj/processor/tools/run_code.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from khoj.database.adapters import ais_user_subscribed
|
||||||
|
from khoj.database.models import Agent, KhojUser
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
|
||||||
|
from khoj.processor.conversation.utils import (
|
||||||
|
ChatEvent,
|
||||||
|
construct_chat_history,
|
||||||
|
remove_json_codeblock,
|
||||||
|
)
|
||||||
|
from khoj.utils.helpers import timer
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_code(
|
||||||
|
query: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
location_data: LocationData,
|
||||||
|
user: KhojUser,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
sandbox_url: str = "http://localhost:8080",
|
||||||
|
):
|
||||||
|
# Generate Code
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func(f"**Generate code snippets** for {query}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
try:
|
||||||
|
with timer("Chat actor: Generate programs to execute", logger):
|
||||||
|
codes = await generate_python_code(
|
||||||
|
query, conversation_history, location_data, user, uploaded_image_url, agent
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||||
|
|
||||||
|
# Run Code
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
try:
|
||||||
|
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
|
||||||
|
with timer("Chat actor: Execute generated programs", logger):
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
for result in results:
|
||||||
|
code = result.pop("code")
|
||||||
|
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
||||||
|
yield {query: {"code": code, "results": result}}
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to run code for {query} with error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_python_code(
|
||||||
|
q: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
location_data: LocationData,
|
||||||
|
user: KhojUser,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
) -> List[str]:
|
||||||
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
|
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||||
|
subscribed = await ais_user_subscribed(user)
|
||||||
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
|
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
personality_context = (
|
||||||
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
||||||
|
current_date=utc_date,
|
||||||
|
query=q,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location=location,
|
||||||
|
username=username,
|
||||||
|
personality_context=personality_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await send_message_to_model_wrapper(
|
||||||
|
code_generation_prompt,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
response_type="json_object",
|
||||||
|
subscribed=subscribed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
response = response.strip()
|
||||||
|
response = remove_json_codeblock(response)
|
||||||
|
response = json.loads(response)
|
||||||
|
codes = [code.strip() for code in response["codes"] if code.strip()]
|
||||||
|
|
||||||
|
if not isinstance(codes, list) or not codes or len(codes) == 0:
|
||||||
|
raise ValueError
|
||||||
|
return codes
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Takes code to run as a string and calls the terrarium API to execute it.
|
||||||
|
Returns the result of the code execution as a dictionary.
|
||||||
|
"""
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
data = {"code": code}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
result: dict[str, Any] = await response.json()
|
||||||
|
result["code"] = code
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}
|
|
@ -19,7 +19,6 @@ from khoj.database.adapters import (
|
||||||
AgentAdapters,
|
AgentAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
|
||||||
PublicConversationAdapters,
|
PublicConversationAdapters,
|
||||||
aget_user_name,
|
aget_user_name,
|
||||||
)
|
)
|
||||||
|
@ -29,6 +28,7 @@ from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
from khoj.processor.image.generate import text_to_image
|
from khoj.processor.image.generate import text_to_image
|
||||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
|
from khoj.processor.tools.run_code import run_code
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
|
@ -46,7 +46,6 @@ from khoj.routers.helpers import (
|
||||||
is_query_empty,
|
is_query_empty,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
read_chat_stream,
|
read_chat_stream,
|
||||||
run_code,
|
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,7 +24,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import pytz
|
import pytz
|
||||||
import requests
|
import requests
|
||||||
|
@ -79,13 +78,16 @@ from khoj.processor.conversation.google.gemini_chat import (
|
||||||
converse_gemini,
|
converse_gemini,
|
||||||
gemini_send_message_to_model,
|
gemini_send_message_to_model,
|
||||||
)
|
)
|
||||||
|
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
|
||||||
from khoj.processor.conversation.offline.chat_model import (
|
from khoj.processor.conversation.offline.chat_model import (
|
||||||
converse_offline,
|
converse_offline,
|
||||||
send_message_to_model_offline,
|
send_message_to_model_offline,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
|
ChatEvent,
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
|
construct_chat_history,
|
||||||
generate_chatml_messages_with_context,
|
generate_chatml_messages_with_context,
|
||||||
remove_json_codeblock,
|
remove_json_codeblock,
|
||||||
save_to_conversation_log,
|
save_to_conversation_log,
|
||||||
|
@ -208,18 +210,6 @@ def get_next_url(request: Request) -> str:
|
||||||
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
|
||||||
chat_history = ""
|
|
||||||
for chat in conversation_history.get("chat", [])[-n:]:
|
|
||||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"{agent_name}: {chat['message']}\n"
|
|
||||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
|
||||||
return chat_history
|
|
||||||
|
|
||||||
|
|
||||||
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
||||||
if query.startswith("/notes"):
|
if query.startswith("/notes"):
|
||||||
return ConversationCommand.Notes
|
return ConversationCommand.Notes
|
||||||
|
@ -520,103 +510,6 @@ async def generate_online_subqueries(
|
||||||
return [q]
|
return [q]
|
||||||
|
|
||||||
|
|
||||||
async def run_code(
|
|
||||||
query: str,
|
|
||||||
conversation_history: dict,
|
|
||||||
location_data: LocationData,
|
|
||||||
user: KhojUser,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
|
||||||
uploaded_image_url: str = None,
|
|
||||||
agent: Agent = None,
|
|
||||||
sandbox_url: str = "http://localhost:8080",
|
|
||||||
):
|
|
||||||
# Generate Code
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func(f"**Generate code snippets** for {query}"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
try:
|
|
||||||
with timer("Chat actor: Generate programs to execute", logger):
|
|
||||||
codes = await generate_python_code(
|
|
||||||
query, conversation_history, location_data, user, uploaded_image_url, agent
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
|
||||||
|
|
||||||
# Run Code
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
try:
|
|
||||||
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
|
|
||||||
with timer("Chat actor: Execute generated programs", logger):
|
|
||||||
results = await asyncio.gather(*tasks)
|
|
||||||
for result in results:
|
|
||||||
code = result.pop("code")
|
|
||||||
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
|
||||||
yield {query: {"code": code, "results": result}}
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Failed to run code for {query} with error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_python_code(
|
|
||||||
q: str,
|
|
||||||
conversation_history: dict,
|
|
||||||
location_data: LocationData,
|
|
||||||
user: KhojUser,
|
|
||||||
uploaded_image_url: str = None,
|
|
||||||
agent: Agent = None,
|
|
||||||
) -> List[str]:
|
|
||||||
location = f"{location_data}" if location_data else "Unknown"
|
|
||||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
|
||||||
chat_history = construct_chat_history(conversation_history)
|
|
||||||
|
|
||||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
|
||||||
personality_context = (
|
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
|
||||||
current_date=utc_date,
|
|
||||||
query=q,
|
|
||||||
chat_history=chat_history,
|
|
||||||
location=location,
|
|
||||||
username=username,
|
|
||||||
personality_context=personality_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await send_message_to_model_wrapper(
|
|
||||||
code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
|
||||||
response = response.strip()
|
|
||||||
response = remove_json_codeblock(response)
|
|
||||||
response = json.loads(response)
|
|
||||||
codes = [code.strip() for code in response["codes"] if code.strip()]
|
|
||||||
|
|
||||||
if not isinstance(codes, list) or not codes or len(codes) == 0:
|
|
||||||
raise ValueError
|
|
||||||
return codes
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Takes code to run as a string and calls the terrarium API to execute it.
|
|
||||||
Returns the result of the code execution as a dictionary.
|
|
||||||
"""
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
data = {"code": code}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
result: dict[str, Any] = await response.json()
|
|
||||||
result["code"] = code
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}
|
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
|
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
|
@ -837,119 +730,6 @@ async def generate_better_image_prompt(
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def send_message_to_model_wrapper(
|
|
||||||
message: str,
|
|
||||||
system_message: str = "",
|
|
||||||
response_type: str = "text",
|
|
||||||
chat_model_option: ChatModelOptions = None,
|
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: str = None,
|
|
||||||
):
|
|
||||||
conversation_config: ChatModelOptions = (
|
|
||||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
)
|
|
||||||
|
|
||||||
vision_available = conversation_config.vision_enabled
|
|
||||||
if not vision_available and uploaded_image_url:
|
|
||||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
|
||||||
if vision_enabled_config:
|
|
||||||
conversation_config = vision_enabled_config
|
|
||||||
vision_available = True
|
|
||||||
|
|
||||||
chat_model = conversation_config.chat_model
|
|
||||||
max_tokens = (
|
|
||||||
conversation_config.subscribed_max_prompt_size
|
|
||||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
|
||||||
else conversation_config.max_prompt_size
|
|
||||||
)
|
|
||||||
tokenizer = conversation_config.tokenizer
|
|
||||||
model_type = conversation_config.model_type
|
|
||||||
vision_available = conversation_config.vision_enabled
|
|
||||||
|
|
||||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
loaded_model=loaded_model,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
|
||||||
messages=truncated_messages,
|
|
||||||
loaded_model=loaded_model,
|
|
||||||
model=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
streaming=False,
|
|
||||||
response_type=response_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
|
||||||
openai_chat_config = conversation_config.openai_config
|
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
api_base_url = openai_chat_config.api_base_url
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return send_message_to_model(
|
|
||||||
messages=truncated_messages,
|
|
||||||
api_key=api_key,
|
|
||||||
model=chat_model,
|
|
||||||
response_type=response_type,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
)
|
|
||||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
|
||||||
api_key = conversation_config.openai_config.api_key
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
|
||||||
messages=truncated_messages,
|
|
||||||
api_key=api_key,
|
|
||||||
model=chat_model,
|
|
||||||
)
|
|
||||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
|
||||||
api_key = conversation_config.openai_config.api_key
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model_wrapper_sync(
|
def send_message_to_model_wrapper_sync(
|
||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
|
@ -1540,14 +1320,6 @@ Manage your automations [here](/automations).
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
class ChatEvent(Enum):
|
|
||||||
START_LLM_RESPONSE = "start_llm_response"
|
|
||||||
END_LLM_RESPONSE = "end_llm_response"
|
|
||||||
MESSAGE = "message"
|
|
||||||
REFERENCES = "references"
|
|
||||||
STATUS = "status"
|
|
||||||
|
|
||||||
|
|
||||||
class MessageProcessor:
|
class MessageProcessor:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.references = {}
|
self.references = {}
|
||||||
|
|
Loading…
Reference in a new issue