mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Refactor Run Code tool into separate module and modularize code functions
Move construct_chat_history and ChatEvent enum into conversation.utils and move send_message_to_model_wrapper to conversation.helper to modularize code. And start thinning out the bloated routers.helper - conversation.util components are shared functions that conversation child packages can use. - conversation.helper components can't be imported by conversation packages but it can use these child packages This division allows better modularity while avoiding circular import dependencies
This commit is contained in:
parent
8044733201
commit
a98f97ed5e
5 changed files with 274 additions and 234 deletions
126
src/khoj/processor/conversation/helpers.py
Normal file
126
src/khoj/processor/conversation/helpers.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
from fastapi import HTTPException
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
|
||||
from khoj.database.models import ChatModelOptions, KhojUser
|
||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
||||
anthropic_send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model
|
||||
from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline
|
||||
from khoj.processor.conversation.openai.gpt import send_message_to_model
|
||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import OfflineChatProcessorModel
|
||||
|
||||
|
||||
async def send_message_to_model_wrapper(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
chat_model_option: ChatModelOptions = None,
|
||||
subscribed: bool = False,
|
||||
uploaded_image_url: str = None,
|
||||
):
|
||||
conversation_config: ChatModelOptions = (
|
||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
||||
)
|
||||
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = (
|
||||
conversation_config.subscribed_max_prompt_size
|
||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||
else conversation_config.max_prompt_size
|
||||
)
|
||||
tokenizer = conversation_config.tokenizer
|
||||
model_type = conversation_config.model_type
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
loaded_model=loaded_model,
|
||||
tokenizer_name=tokenizer,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
@ -2,6 +2,7 @@ import logging
|
|||
import math
|
||||
import queue
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -10,7 +11,7 @@ from langchain.schema import ChatMessage
|
|||
from llama_cpp.llama import Llama
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
|
||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.utils import state
|
||||
|
@ -75,6 +76,26 @@ class ThreadedGenerator:
|
|||
self.queue.put(StopIteration)
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['message']}\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
return chat_history
|
||||
|
||||
|
||||
class ChatEvent(Enum):
|
||||
START_LLM_RESPONSE = "start_llm_response"
|
||||
END_LLM_RESPONSE = "end_llm_response"
|
||||
MESSAGE = "message"
|
||||
REFERENCES = "references"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||
):
|
||||
|
|
122
src/khoj/processor/tools/run_code.py
Normal file
122
src/khoj/processor/tools/run_code.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from khoj.database.adapters import ais_user_subscribed
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
construct_chat_history,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_code(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
agent: Agent = None,
|
||||
sandbox_url: str = "http://localhost:8080",
|
||||
):
|
||||
# Generate Code
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Generate code snippets** for {query}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
try:
|
||||
with timer("Chat actor: Generate programs to execute", logger):
|
||||
codes = await generate_python_code(
|
||||
query, conversation_history, location_data, user, uploaded_image_url, agent
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||
|
||||
# Run Code
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
try:
|
||||
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
|
||||
with timer("Chat actor: Execute generated programs", logger):
|
||||
results = await asyncio.gather(*tasks)
|
||||
for result in results:
|
||||
code = result.pop("code")
|
||||
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
||||
yield {query: {"code": code, "results": result}}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to run code for {query} with error: {e}")
|
||||
|
||||
|
||||
async def generate_python_code(
|
||||
q: str,
|
||||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
subscribed = await ais_user_subscribed(user)
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
||||
current_date=utc_date,
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(
|
||||
code_generation_prompt,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
response_type="json_object",
|
||||
subscribed=subscribed,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = json.loads(response)
|
||||
codes = [code.strip() for code in response["codes"] if code.strip()]
|
||||
|
||||
if not isinstance(codes, list) or not codes or len(codes) == 0:
|
||||
raise ValueError
|
||||
return codes
|
||||
|
||||
|
||||
async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
|
||||
"""
|
||||
Takes code to run as a string and calls the terrarium API to execute it.
|
||||
Returns the result of the code execution as a dictionary.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"code": code}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result: dict[str, Any] = await response.json()
|
||||
result["code"] = code
|
||||
return result
|
||||
else:
|
||||
return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}
|
|
@ -19,7 +19,6 @@ from khoj.database.adapters import (
|
|||
AgentAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
PublicConversationAdapters,
|
||||
aget_user_name,
|
||||
)
|
||||
|
@ -29,6 +28,7 @@ from khoj.processor.conversation.utils import save_to_conversation_log
|
|||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
|
@ -46,7 +46,6 @@ from khoj.routers.helpers import (
|
|||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
read_chat_stream,
|
||||
run_code,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
)
|
||||
|
|
|
@ -24,7 +24,6 @@ from typing import (
|
|||
)
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import cron_descriptor
|
||||
import pytz
|
||||
import requests
|
||||
|
@ -79,13 +78,16 @@ from khoj.processor.conversation.google.gemini_chat import (
|
|||
converse_gemini,
|
||||
gemini_send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
send_message_to_model_offline,
|
||||
)
|
||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
construct_chat_history,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
save_to_conversation_log,
|
||||
|
@ -208,18 +210,6 @@ def get_next_url(request: Request) -> str:
|
|||
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||
chat_history = ""
|
||||
for chat in conversation_history.get("chat", [])[-n:]:
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: {chat['message']}\n"
|
||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||
chat_history += f"User: {chat['intent']['query']}\n"
|
||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||
return chat_history
|
||||
|
||||
|
||||
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
||||
if query.startswith("/notes"):
|
||||
return ConversationCommand.Notes
|
||||
|
@ -520,103 +510,6 @@ async def generate_online_subqueries(
|
|||
return [q]
|
||||
|
||||
|
||||
async def run_code(
|
||||
query: str,
|
||||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
agent: Agent = None,
|
||||
sandbox_url: str = "http://localhost:8080",
|
||||
):
|
||||
# Generate Code
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Generate code snippets** for {query}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
try:
|
||||
with timer("Chat actor: Generate programs to execute", logger):
|
||||
codes = await generate_python_code(
|
||||
query, conversation_history, location_data, user, uploaded_image_url, agent
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||
|
||||
# Run Code
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
try:
|
||||
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
|
||||
with timer("Chat actor: Execute generated programs", logger):
|
||||
results = await asyncio.gather(*tasks)
|
||||
for result in results:
|
||||
code = result.pop("code")
|
||||
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
|
||||
yield {query: {"code": code, "results": result}}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to run code for {query} with error: {e}")
|
||||
|
||||
|
||||
async def generate_python_code(
|
||||
q: str,
|
||||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
code_generation_prompt = prompts.python_code_generation_prompt.format(
|
||||
current_date=utc_date,
|
||||
query=q,
|
||||
chat_history=chat_history,
|
||||
location=location,
|
||||
username=username,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
response = await send_message_to_model_wrapper(
|
||||
code_generation_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = json.loads(response)
|
||||
codes = [code.strip() for code in response["codes"] if code.strip()]
|
||||
|
||||
if not isinstance(codes, list) or not codes or len(codes) == 0:
|
||||
raise ValueError
|
||||
return codes
|
||||
|
||||
|
||||
async def execute_sandboxed_python(code: str, sandbox_url: str = "http://localhost:8080") -> dict[str, Any]:
|
||||
"""
|
||||
Takes code to run as a string and calls the terrarium API to execute it.
|
||||
Returns the result of the code execution as a dictionary.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"code": code}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(sandbox_url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result: dict[str, Any] = await response.json()
|
||||
result["code"] = code
|
||||
return result
|
||||
else:
|
||||
return {"code": code, "success": False, "std_err": f"Failed to execute code with {response.status}"}
|
||||
|
||||
|
||||
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -837,119 +730,6 @@ async def generate_better_image_prompt(
|
|||
return response
|
||||
|
||||
|
||||
async def send_message_to_model_wrapper(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
chat_model_option: ChatModelOptions = None,
|
||||
subscribed: bool = False,
|
||||
uploaded_image_url: str = None,
|
||||
):
|
||||
conversation_config: ChatModelOptions = (
|
||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
||||
)
|
||||
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = (
|
||||
conversation_config.subscribed_max_prompt_size
|
||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||
else conversation_config.max_prompt_size
|
||||
)
|
||||
tokenizer = conversation_config.tokenizer
|
||||
model_type = conversation_config.model_type
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
loaded_model=loaded_model,
|
||||
tokenizer_name=tokenizer,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return anthropic_send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
||||
def send_message_to_model_wrapper_sync(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
|
@ -1540,14 +1320,6 @@ Manage your automations [here](/automations).
|
|||
""".strip()
|
||||
|
||||
|
||||
class ChatEvent(Enum):
|
||||
START_LLM_RESPONSE = "start_llm_response"
|
||||
END_LLM_RESPONSE = "end_llm_response"
|
||||
MESSAGE = "message"
|
||||
REFERENCES = "references"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
def __init__(self):
|
||||
self.references = {}
|
||||
|
|
Loading…
Reference in a new issue