Track and return cost and usage metrics in chat API response

- Track input and output token usage and cost of interactions
  via the chat API with OpenAI, Anthropic and Google chat models

- Get usage metadata from OpenAI using stream_options
- Handle OpenAI proxies that do not support passing usage in the response

- Add new usage and end_response events returned by the chat API
  - Clients can optionally consume these at a later point
  - Update streaming clients to mark a message as completed after the new
    end_response event, not after the end_llm_response event
- Ensure usage data from the final response generation step is included
  - Pass usage data after the LLM response is complete. This allows
    gathering token usage and cost for the final response generation step
    across streaming and non-streaming modes
Debanjum 2024-11-18 18:23:05 -08:00
parent 7bdc9590dd
commit c53c3db96b
10 changed files with 139 additions and 38 deletions
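
The new usage and end_response events extend the event stream returned by the chat API: usage carries the accumulated token counts and cost, while end_response marks the true end of the response. A minimal client-side sketch of consuming them, assuming each streamed chunk decodes to a dict with "type" and "data" keys; the decode_chat_stream helper is hypothetical, while the event names and the usage payload shape come from the diffs below:

    for chunk in decode_chat_stream(response):  # hypothetical chunk decoder
        if chunk["type"] == "usage":
            # {"input_tokens": int, "output_tokens": int, "cost": float}
            usage = chunk["data"]
        elif chunk["type"] == "end_llm_response":
            pass  # LLM output is done, but usage data may still arrive
        elif chunk["type"] == "end_response":
            break  # only now should clients mark the message as completed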


@@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
                 console.log("Started streaming", new Date());
             } else if (chunk.type === 'end_llm_response') {
                 console.log("Stopped streaming", new Date());
-
+            } else if (chunk.type === 'end_response') {
                 // Automatically respond with voice if the subscribed user has sent voice message
                 if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
                     this.textToSpeech(this.chatMessageState.rawResponse);


@@ -133,7 +133,7 @@ export function processMessageChunk(
         console.log(`Started streaming: ${new Date()}`);
     } else if (chunk.type === "end_llm_response") {
         console.log(`Completed streaming: ${new Date()}`);
-
+    } else if (chunk.type === "end_response") {
         // Append any references after all the data has been streamed
         if (codeContext) currentMessage.codeContext = codeContext;
         if (onlineContext) currentMessage.onlineContext = onlineContext;


@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
     aggregated_response = "{" if response_type == "json_object" else ""
     max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
+    final_message = None

     model_kwargs = model_kwargs or dict()
     if system_prompt:
         model_kwargs["system"] = system_prompt
@@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
     ) as stream:
         for text in stream.text_stream:
             aggregated_response += text
+        final_message = stream.get_final_message()
+
+    # Calculate cost of chat
+    input_tokens = final_message.usage.input_tokens
+    output_tokens = final_message.usage.output_tokens
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model_name
@@ -126,6 +133,7 @@ def anthropic_llm_thread(
     ]

     aggregated_response = ""
+    final_message = None
     with client.messages.stream(
         messages=formatted_messages,
         model=model_name,  # type: ignore
@@ -138,6 +146,12 @@ def anthropic_llm_thread(
         for text in stream.text_stream:
             aggregated_response += text
             g.send(text)
+        final_message = stream.get_final_message()
+
+    # Calculate cost of chat
+    input_tokens = final_message.usage.input_tokens
+    output_tokens = final_message.usage.output_tokens
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model_name


@@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
         response = chat_session.send_message(formatted_messages[-1]["parts"])
         response_text = response.text
     except StopCandidateException as e:
+        response = None
         response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
@@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )

+    # Aggregate cost of chat
+    input_tokens = response.usage_metadata.prompt_token_count if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
@@ -146,6 +152,11 @@ def gemini_llm_thread(
                 if stopped:
                     raise StopCandidateException(message)

+        # Calculate cost of chat
+        input_tokens = chunk.usage_metadata.prompt_token_count
+        output_tokens = chunk.usage_metadata.candidates_token_count
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature


@@ -4,6 +4,8 @@ from threading import Thread
 from typing import Dict

 import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from tenacity import (
     before_sleep_log,
     retry,
@@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
     commit_conversation_trace,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode

 logger = logging.getLogger(__name__)
@@ -64,27 +66,34 @@ def completion_with_backoff(
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat = client.chat.completions.create(
-        stream=stream,
+    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
+        stream=stream,
+        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )

-    if not stream:
-        return chat.choices[0].message.content
-
     aggregated_response = ""
-    for chunk in chat:
-        if len(chunk.choices) == 0:
-            continue
-        delta_chunk = chunk.choices[0].delta  # type: ignore
-        if isinstance(delta_chunk, str):
-            aggregated_response += delta_chunk
-        elif delta_chunk.content:
-            aggregated_response += delta_chunk.content
+    if not stream:
+        chunk = chat
+        aggregated_response = chunk.choices[0].message.content
+    else:
+        for chunk in chat:
+            if len(chunk.choices) == 0:
+                continue
+            delta_chunk = chunk.choices[0].delta  # type: ignore
+            if isinstance(delta_chunk, str):
+                aggregated_response += delta_chunk
+            elif delta_chunk.content:
+                aggregated_response += delta_chunk.content
+
+    # Calculate cost of chat
+    input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model
@@ -164,10 +173,11 @@ def llm_thread(
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat = client.chat.completions.create(
-        stream=stream,
+    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,
         model=model_name,  # type: ignore
+        stream=stream,
+        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
@@ -175,7 +185,8 @@ def llm_thread(
     aggregated_response = ""
     if not stream:
-        aggregated_response = chat.choices[0].message.content
+        chunk = chat
+        aggregated_response = chunk.choices[0].message.content
         g.send(aggregated_response)
     else:
         for chunk in chat:
@@ -191,6 +202,11 @@ def llm_thread(
                 aggregated_response += text_chunk
                 g.send(text_chunk)

+        # Calculate cost of chat
+        input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+        output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature


@@ -5,7 +5,6 @@ import math
 import mimetypes
 import os
 import queue
-import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
@@ -57,7 +56,7 @@ model_to_prompt_size = {
     "gemini-1.5-flash": 20000,
     "gemini-1.5-pro": 20000,
     # Anthropic Models
-    "claude-3-5-sonnet-20240620": 20000,
+    "claude-3-5-sonnet-20241022": 20000,
     "claude-3-5-haiku-20241022": 20000,
     # Offline Models
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -213,6 +212,8 @@ class ChatEvent(Enum):
     REFERENCES = "references"
     STATUS = "status"
     METADATA = "metadata"
+    USAGE = "usage"
+    END_RESPONSE = "end_response"


 def message_to_log(


@@ -667,27 +667,37 @@ async def chat(
         finally:
             yield event_delimiter

-    async def send_llm_response(response: str):
+    async def send_llm_response(response: str, usage: dict = None):
+        # Send Chat Response
         async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
         async for result in send_event(ChatEvent.MESSAGE, response):
             yield result
         async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
             yield result
+        # Send Usage Metadata once llm interactions are complete
+        if usage:
+            async for event in send_event(ChatEvent.USAGE, usage):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result

     def collect_telemetry():
         # Gather chat response telemetry
         nonlocal chat_metadata
         latency = time.perf_counter() - start_time
         cmd_set = set([cmd.value for cmd in conversation_commands])
+        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
         chat_metadata = chat_metadata or {}
         chat_metadata["conversation_command"] = cmd_set
         chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
         chat_metadata["latency"] = f"{latency:.3f}"
         chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+        chat_metadata["usage"] = tracer.get("usage")

         logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
         logger.info(f"Chat response total time: {latency:.3f} seconds")
+        logger.info(f"Chat response cost: ${cost:.5f}")
         update_telemetry_state(
             request=request,
             telemetry_type="api",
@@ -699,7 +709,7 @@ async def chat(
     )

     if is_query_empty(q):
-        async for result in send_llm_response("Please ask your query to get started."):
+        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
             yield result
         return
@@ -713,7 +723,7 @@ async def chat(
             create_new=body.create_new,
         )
         if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found"):
+            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
                 yield result
             return
         conversation_id = conversation.id
@@ -777,7 +787,7 @@ async def chat(
                 await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
                 q = q.replace(f"/{cmd.value}", "").strip()
             except HTTPException as e:
-                async for result in send_llm_response(str(e.detail)):
+                async for result in send_llm_response(str(e.detail), tracer.get("usage")):
                     yield result
                 return
@@ -834,7 +844,7 @@ async def chat(
             agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
             if len(file_filters) == 0 and not agent_has_entries:
                 response_log = "No files selected for summarization. Please add files using the section on the left."
-                async for result in send_llm_response(response_log):
+                async for result in send_llm_response(response_log, tracer.get("usage")):
                     yield result
             else:
                 async for response in generate_summary_from_files(
@@ -853,7 +863,7 @@ async def chat(
                     else:
                         if isinstance(response, str):
                             response_log = response
-                            async for result in send_llm_response(response):
+                            async for result in send_llm_response(response, tracer.get("usage")):
                                 yield result

             await sync_to_async(save_to_conversation_log)(
@@ -880,7 +890,7 @@ async def chat(
                 conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
             model_type = conversation_config.model_type
             formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
-            async for result in send_llm_response(formatted_help):
+            async for result in send_llm_response(formatted_help, tracer.get("usage")):
                 yield result
             return
         # Adding specification to search online specifically on khoj.dev pages.
@@ -895,7 +905,7 @@ async def chat(
             except Exception as e:
                 logger.error(f"Error scheduling task {q} for {user.email}: {e}")
                 error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
-                async for result in send_llm_response(error_message):
+                async for result in send_llm_response(error_message, tracer.get("usage")):
                     yield result
                 return
@@ -916,7 +926,7 @@ async def chat(
                 raw_query_files=raw_query_files,
                 tracer=tracer,
             )
-            async for result in send_llm_response(llm_response):
+            async for result in send_llm_response(llm_response, tracer.get("usage")):
                 yield result
             return
@@ -963,7 +973,7 @@ async def chat(
                 yield result

         if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-            async for result in send_llm_response(f"{no_entries_found.format()}"):
+            async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
                 yield result
             return
@@ -1105,7 +1115,7 @@ async def chat(
                     "detail": improved_image_prompt,
                     "image": None,
                 }
-                async for result in send_llm_response(json.dumps(content_obj)):
+                async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                     yield result
                 return
@@ -1132,7 +1142,7 @@ async def chat(
                 "inferredQueries": [improved_image_prompt],
                 "image": generated_image,
             }
-            async for result in send_llm_response(json.dumps(content_obj)):
+            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                 yield result
             return
@@ -1166,7 +1176,7 @@ async def chat(
                     diagram_description = excalidraw_diagram_description
                 else:
                     error_message = "Failed to generate diagram. Please try again later."
-                    async for result in send_llm_response(error_message):
+                    async for result in send_llm_response(error_message, tracer.get("usage")):
                         yield result

                     await sync_to_async(save_to_conversation_log)(
@@ -1213,7 +1223,7 @@ async def chat(
                 tracer=tracer,
             )

-            async for result in send_llm_response(json.dumps(content_obj)):
+            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                 yield result
             return
@@ -1252,6 +1262,11 @@ async def chat(
             if item is None:
                 async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                     yield result
+                # Send Usage Metadata once llm interactions are complete
+                async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                    yield event
+                async for result in send_event(ChatEvent.END_RESPONSE, ""):
+                    yield result
                 logger.debug("Finished streaming response")
                 return
             if not connection_alive or not continue_stream:


@@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
 class MessageProcessor:
     def __init__(self):
         self.references = {}
+        self.usage = {}
         self.raw_response = ""

     def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
@@ -1793,6 +1794,8 @@ class MessageProcessor:
             chunk_type = ChatEvent(chunk["type"])
             if chunk_type == ChatEvent.REFERENCES:
                 self.references = chunk["data"]
+            elif chunk_type == ChatEvent.USAGE:
+                self.usage = chunk["data"]
             elif chunk_type == ChatEvent.MESSAGE:
                 chunk_data = chunk["data"]
                 if isinstance(chunk_data, dict):
@@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
     if buffer:
         processor.process_message_chunk(buffer)

-    return {"response": processor.raw_response, "references": processor.references}
+    return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}


 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):


@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Dict

 app_root_directory = Path(__file__).parent.parent.parent
 web_directory = app_root_directory / "khoj/interface/web/"
@@ -31,3 +32,19 @@ default_config = {
         "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
     },
 }
+
+model_to_cost: Dict[str, Dict[str, float]] = {
+    # OpenAI Pricing: https://openai.com/api/pricing/
+    "gpt-4o": {"input": 2.50, "output": 10.00},
+    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
+    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1-mini": {"input": 3.0, "output": 12.0},
+    # Gemini Pricing: https://ai.google.dev/pricing
+    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
+    "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
+    # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
+    "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
+    "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
+}


@@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
 def get_country_name_from_timezone(tz: str) -> str:
     """Get country name from timezone"""
     return country_names.get(get_country_code_from_timezone(tz), "United States")
+
+
+def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
+    """
+    Calculate cost of chat message based on input and output tokens
+    """
+    # Calculate cost of input and output tokens. Costs are per million tokens
+    input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
+    output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
+
+    return input_cost + output_cost + prev_cost
+
+
+def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
+    """
+    Get usage metrics for chat message based on input and output tokens
+    """
+    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
+    return {
+        "input_tokens": prev_usage["input_tokens"] + input_tokens,
+        "output_tokens": prev_usage["output_tokens"] + output_tokens,
+        "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
+    }