diff --git a/src/khoj/processor/conversation/helpers.py b/src/khoj/processor/conversation/helpers.py
new file mode 100644
index 00000000..06a8557c
--- /dev/null
+++ b/src/khoj/processor/conversation/helpers.py
@@ -0,0 +1,126 @@
+from fastapi import HTTPException
+
+from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
+from khoj.database.models import ChatModelOptions, KhojUser
+from khoj.processor.conversation.anthropic.anthropic_chat import (
+    anthropic_send_message_to_model,
+)
+from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model
+from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline
+from khoj.processor.conversation.openai.gpt import send_message_to_model
+from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.utils import state
+from khoj.utils.config import OfflineChatProcessorModel
+
+
+async def send_message_to_model_wrapper(
+    message: str,
+    system_message: str = "",
+    response_type: str = "text",
+    chat_model_option: ChatModelOptions = None,
+    subscribed: bool = False,
+    uploaded_image_url: str = None,
+):
+    conversation_config: ChatModelOptions = (
+        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
+    )
+
+    vision_available = conversation_config.vision_enabled
+    if not vision_available and uploaded_image_url:
+        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
+        if vision_enabled_config:
+            conversation_config = vision_enabled_config
+            vision_available = True
+
+    chat_model = conversation_config.chat_model
+    max_tokens = (
+        conversation_config.subscribed_max_prompt_size
+        if subscribed and conversation_config.subscribed_max_prompt_size
+        else conversation_config.max_prompt_size
+    )
+    tokenizer = conversation_config.tokenizer
+    model_type = conversation_config.model_type
+    vision_available = conversation_config.vision_enabled
+
+    if model_type == ChatModelOptions.ModelType.OFFLINE:
+        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
+
+        loaded_model = state.offline_chat_processor_config.loaded_model
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            loaded_model=loaded_model,
+            tokenizer_name=tokenizer,
+            max_prompt_size=max_tokens,
+            vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
+        )
+
+        return send_message_to_model_offline(
+            messages=truncated_messages,
+            loaded_model=loaded_model,
+            model=chat_model,
+            max_prompt_size=max_tokens,
+            streaming=False,
+            response_type=response_type,
+        )
+
+    elif model_type == ChatModelOptions.ModelType.OPENAI:
+        openai_chat_config = conversation_config.openai_config
+        api_key = openai_chat_config.api_key
+        api_base_url = openai_chat_config.api_base_url
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            tokenizer_name=tokenizer,
+            vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
+            model_type=conversation_config.model_type,
+        )
+
+        return send_message_to_model(
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+            response_type=response_type,
+            api_base_url=api_base_url,
+        )
+    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            tokenizer_name=tokenizer,
+            vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
+            model_type=conversation_config.model_type,
+        )
+
+        return anthropic_send_message_to_model(
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+        )
+    elif model_type == ChatModelOptions.ModelType.GOOGLE:
+        api_key = conversation_config.openai_config.api_key
+        truncated_messages = generate_chatml_messages_with_context(
+            user_message=message,
+            system_message=system_message,
+            model_name=chat_model,
+            max_prompt_size=max_tokens,
+            tokenizer_name=tokenizer,
+            vision_enabled=vision_available,
+            uploaded_image_url=uploaded_image_url,
+        )
+
+        return gemini_send_message_to_model(
+            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+        )
+    else:
+        raise HTTPException(status_code=500, detail="Invalid conversation config")
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 9a2ba230..00ef56d9 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -2,6 +2,7 @@ import logging
 import math
 import queue
 from datetime import datetime
+from enum import Enum
 from time import perf_counter
 from typing import Any, Dict, List, Optional
 
@@ -10,7 +11,7 @@ from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
 from transformers import AutoTokenizer
 
-from khoj.database.adapters import ConversationAdapters
+from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
 from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
@@ -75,6 +76,26 @@ class ThreadedGenerator:
         self.queue.put(StopIteration)
 
 
+def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
+    chat_history = ""
+    for chat in conversation_history.get("chat", [])[-n:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: {chat['message']}\n"
+        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: [generated image redacted for space]\n"
+    return chat_history
+
+
+class ChatEvent(Enum):
+    START_LLM_RESPONSE = "start_llm_response"
+    END_LLM_RESPONSE = "end_llm_response"
+    MESSAGE = "message"
+    REFERENCES = "references"
+    STATUS = "status"
+
+
 def message_to_log(
     user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
 ):
diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py
new file mode 100644
index 00000000..2fd2e8d2
--- /dev/null
+++ b/src/khoj/processor/tools/run_code.py
@@ -0,0 +1,122 @@
+import asyncio
+import datetime
+import json
+import logging
+from typing import Any, Callable, List, Optional
+
+import aiohttp
+
+from khoj.database.adapters import ais_user_subscribed
+from khoj.database.models import Agent, KhojUser
+from khoj.processor.conversation import prompts
+from khoj.processor.conversation.helpers import send_message_to_model_wrapper
+from khoj.processor.conversation.utils import (
+    ChatEvent,
+    construct_chat_history,
+    remove_json_codeblock,
+)
+from khoj.utils.helpers import timer
+from khoj.utils.rawconfig import LocationData
+
+logger = logging.getLogger(__name__)
+
+
+async def run_code(
+    query: str,
+    conversation_history: dict,
+    location_data: LocationData,
+    user: KhojUser,
+    send_status_func: Optional[Callable] = None,
+    uploaded_image_url: str = None,
+    agent: Agent = None,
+    sandbox_url: str = "http://localhost:8080",
+):
+    # Generate Code
+    if send_status_func:
+        async for event in send_status_func(f"**Generate code snippets** for {query}"):
+            yield {ChatEvent.STATUS: event}
+    try:
+        with timer("Chat actor: Generate programs to execute", logger):
+            codes = await generate_python_code(
+                query, conversation_history, location_data, user, uploaded_image_url, agent
+            )
+    except Exception as e:
+        raise ValueError(f"Failed to generate code for {query} with error: {e}")
+
+    # Run Code
+    if send_status_func:
+        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
+            yield {ChatEvent.STATUS: event}
+    try:
+        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
+        with timer("Chat actor: Execute generated programs", logger):
+            results = await asyncio.gather(*tasks)
+        for result in results:
+            code = result.pop("code")
+            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
+            yield {query: {"code": code, "results": result}}
+    except Exception as e:
+        raise ValueError(f"Failed to run code for {query} with error: {e}")
+
+
+async def generate_python_code(
+    q: str,
+    conversation_history: dict,
+    location_data: LocationData,
+    user: KhojUser,
+    uploaded_image_url: str = None,
+    agent: Agent = None,
+) -> List[str]:
+    location = f"{location_data}" if location_data else "Unknown"
+    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
+    subscribed = await ais_user_subscribed(user)
+    chat_history = construct_chat_history(conversation_history)
+
+    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
+    personality_context = (
+        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+    )
+
+    code_generation_prompt = prompts.python_code_generation_prompt.format(
+        current_date=utc_date,
+        query=q,
+        chat_history=chat_history,
+        location=location,
+        username=username,
+        personality_context=personality_context,
+    )
+
+    response = await send_message_to_model_wrapper(
+        code_generation_prompt,
+        uploaded_image_url=uploaded_image_url,
+        response_type="json_object",
+        subscribed=subscribed,
+    )
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    response = response.strip()
+    response = remove_json_codeblock(response)
+    response = json.loads(response)
+    codes = [code.strip() for code in response["codes"] if code.strip()]
+
+    if not isinstance(codes, list) or not codes or len(codes) == 0:
+        raise ValueError
+    return codes
+
+
+async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
+    """
+    Takes code to run as a string and calls the terrarium API to execute it.
+    Returns the result of the code execution as a dictionary.
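+
+    The response shape is an assumption based on the error branch below: a JSON
+    object with at least a "success" flag and stdout/stderr style fields (e.g.
+    "std_err"), to which this function adds the executed "code" for reference.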
+ """ + headers = {"Content-Type": "application/json"} + data = {"code": code} + + async with aiohttp.ClientSession() as session: + async with session.post(sandbox_url, json=data, headers=headers) as response: + if response.status == 200: + result: dict[str, Any] = await response.json() + result["code"] = code + return result + else: + return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"} diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index cdf16bd9..82d6f351 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -19,7 +19,6 @@ from khoj.database.adapters import ( AgentAdapters, ConversationAdapters, EntryAdapters, - FileObjectAdapters, PublicConversationAdapters, aget_user_name, ) @@ -29,6 +28,7 @@ from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online +from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ApiUserRateLimiter, @@ -46,7 +46,6 @@ from khoj.routers.helpers import ( is_query_empty, is_ready_to_chat, read_chat_stream, - run_code, update_telemetry_state, validate_conversation_config, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f17fe1f5..3b97e694 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -24,7 +24,6 @@ from typing import ( ) from urllib.parse import parse_qs, quote, urljoin, urlparse -import aiohttp import cron_descriptor import pytz import requests @@ -79,13 +78,16 @@ from khoj.processor.conversation.google.gemini_chat import ( converse_gemini, gemini_send_message_to_model, ) +from khoj.processor.conversation.helpers import send_message_to_model_wrapper from khoj.processor.conversation.offline.chat_model import ( converse_offline, send_message_to_model_offline, ) from khoj.processor.conversation.openai.gpt import converse, send_message_to_model from khoj.processor.conversation.utils import ( + ChatEvent, ThreadedGenerator, + construct_chat_history, generate_chatml_messages_with_context, remove_json_codeblock, save_to_conversation_log, @@ -208,18 +210,6 @@ def get_next_url(request: Request) -> str: return urljoin(str(request.base_url).rstrip("/"), next_path) -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['message']}\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - return chat_history - - def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand: if query.startswith("/notes"): return ConversationCommand.Notes @@ -520,103 +510,6 @@ async def generate_online_subqueries( return [q] -async def run_code( - query: str, - conversation_history: dict, - location_data: LocationData, - user: KhojUser, - send_status_func: Optional[Callable] = None, - uploaded_image_url: str = None, - agent: Agent = None, - 
sandbox_url: str = "http://localhost:8080", -): - # Generate Code - if send_status_func: - async for event in send_status_func(f"**Generate code snippets** for {query}"): - yield {ChatEvent.STATUS: event} - try: - with timer("Chat actor: Generate programs to execute", logger): - codes = await generate_python_code( - query, conversation_history, location_data, user, uploaded_image_url, agent - ) - except Exception as e: - raise ValueError(f"Failed to generate code for {query} with error: {e}") - - # Run Code - if send_status_func: - async for event in send_status_func(f"**Running {len(codes)} code snippets**"): - yield {ChatEvent.STATUS: event} - try: - tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes] - with timer("Chat actor: Execute generated programs", logger): - results = await asyncio.gather(*tasks) - for result in results: - code = result.pop("code") - logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--") - yield {query: {"code": code, "results": result}} - except Exception as e: - raise ValueError(f"Failed to run code for {query} with error: {e}") - - -async def generate_python_code( - q: str, - conversation_history: dict, - location_data: LocationData, - user: KhojUser, - uploaded_image_url: str = None, - agent: Agent = None, -) -> List[str]: - location = f"{location_data}" if location_data else "Unknown" - username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" - chat_history = construct_chat_history(conversation_history) - - utc_date = datetime.utcnow().strftime("%Y-%m-%d") - personality_context = ( - prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" - ) - - code_generation_prompt = prompts.python_code_generation_prompt.format( - current_date=utc_date, - query=q, - chat_history=chat_history, - location=location, - username=username, - personality_context=personality_context, - ) - - response = await send_message_to_model_wrapper( - code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user - ) - - # Validate that the response is a non-empty, JSON-serializable list - response = response.strip() - response = remove_json_codeblock(response) - response = json.loads(response) - codes = [code.strip() for code in response["codes"] if code.strip()] - - if not isinstance(codes, list) or not codes or len(codes) == 0: - raise ValueError - return codes - - -async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]: - """ - Takes code to run as a string and calls the terrarium API to execute it. - Returns the result of the code execution as a dictionary. - """ - headers = {"Content-Type": "application/json"} - data = {"code": code} - - async with aiohttp.ClientSession() as session: - async with session.post(sandbox_url, json=data, headers=headers) as response: - if response.status == 200: - result: dict[str, Any] = await response.json() - result["code"] = code - return result - else: - return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"} - - async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. 
@@ -837,119 +730,6 @@ async def generate_better_image_prompt(
     return response
 
 
-async def send_message_to_model_wrapper(
-    message: str,
-    system_message: str = "",
-    response_type: str = "text",
-    chat_model_option: ChatModelOptions = None,
-    subscribed: bool = False,
-    uploaded_image_url: str = None,
-):
-    conversation_config: ChatModelOptions = (
-        chat_model_option or await ConversationAdapters.aget_default_conversation_config()
-    )
-
-    vision_available = conversation_config.vision_enabled
-    if not vision_available and uploaded_image_url:
-        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
-        if vision_enabled_config:
-            conversation_config = vision_enabled_config
-            vision_available = True
-
-    chat_model = conversation_config.chat_model
-    max_tokens = (
-        conversation_config.subscribed_max_prompt_size
-        if subscribed and conversation_config.subscribed_max_prompt_size
-        else conversation_config.max_prompt_size
-    )
-    tokenizer = conversation_config.tokenizer
-    model_type = conversation_config.model_type
-    vision_available = conversation_config.vision_enabled
-
-    if model_type == ChatModelOptions.ModelType.OFFLINE:
-        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
-
-        loaded_model = state.offline_chat_processor_config.loaded_model
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model,
-            loaded_model=loaded_model,
-            tokenizer_name=tokenizer,
-            max_prompt_size=max_tokens,
-            vision_enabled=vision_available,
-            model_type=conversation_config.model_type,
-        )
-
-        return send_message_to_model_offline(
-            messages=truncated_messages,
-            loaded_model=loaded_model,
-            model=chat_model,
-            max_prompt_size=max_tokens,
-            streaming=False,
-            response_type=response_type,
-        )
-
-    elif model_type == ChatModelOptions.ModelType.OPENAI:
-        openai_chat_config = conversation_config.openai_config
-        api_key = openai_chat_config.api_key
-        api_base_url = openai_chat_config.api_base_url
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
-            model_type=conversation_config.model_type,
-        )
-
-        return send_message_to_model(
-            messages=truncated_messages,
-            api_key=api_key,
-            model=chat_model,
-            response_type=response_type,
-            api_base_url=api_base_url,
-        )
-    elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
-        api_key = conversation_config.openai_config.api_key
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
-            model_type=conversation_config.model_type,
-        )
-
-        return anthropic_send_message_to_model(
-            messages=truncated_messages,
-            api_key=api_key,
-            model=chat_model,
-        )
-    elif model_type == ChatModelOptions.ModelType.GOOGLE:
-        api_key = conversation_config.openai_config.api_key
-        truncated_messages = generate_chatml_messages_with_context(
-            user_message=message,
-            system_message=system_message,
-            model_name=chat_model,
-            max_prompt_size=max_tokens,
-            tokenizer_name=tokenizer,
-            vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
-        )
-
-        return gemini_send_message_to_model(
-            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
-        )
-    else:
-        raise HTTPException(status_code=500, detail="Invalid conversation config")
-
-
 def send_message_to_model_wrapper_sync(
     message: str,
     system_message: str = "",
@@ -1540,14 +1320,6 @@ Manage your automations [here](/automations).
 """.strip()
 
 
-class ChatEvent(Enum):
-    START_LLM_RESPONSE = "start_llm_response"
-    END_LLM_RESPONSE = "end_llm_response"
-    MESSAGE = "message"
-    REFERENCES = "references"
-    STATUS = "status"
-
-
 class MessageProcessor:
     def __init__(self):
         self.references = {}
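For reviewers: a minimal sketch of exercising the relocated `execute_sandboxed_python` helper against a locally running terrarium sandbox. The sample program and the printed fields are illustrative assumptions; only the default `sandbox_url` and the `success`/`std_err` keys are taken from the diff itself.

```python
import asyncio

from khoj.processor.tools.run_code import execute_sandboxed_python


async def main():
    # Assumes a terrarium sandbox is listening on the diff's default URL.
    result = await execute_sandboxed_python("print(21 * 2)", sandbox_url="http://localhost:8080")
    # Non-200 responses are converted into {"code": ..., "success": False, "std_err": ...},
    # so HTTP failures surface as data rather than exceptions (connection errors still raise).
    # "std_out" is assumed to be part of terrarium's response payload, hence .get().
    print(result.get("success"), result.get("std_err"), result.get("std_out"))


asyncio.run(main())
```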