Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 15:38:55 +01:00)
Track, return cost and usage metrics in chat api response
- Track input and output token usage and cost for interactions via the chat API with OpenAI, Anthropic and Google chat models
- Get usage metadata from OpenAI using stream_options
- Handle OpenAI proxies that do not support passing usage in the response
- Add new usage and end response events returned by the chat API
  - These can optionally be consumed by clients at a later point
- Update streaming clients to mark a message as completed after the new end response event, not after the end llm response event (see the consumer sketch below)
- Ensure usage data from the final response generation step is included
  - Pass usage data after the llm response is complete. This allows gathering token usage and cost for the final response generation step across streaming and non-streaming modes
This commit is contained in:
parent 7bdc9590dd
commit c53c3db96b
10 changed files with 139 additions and 38 deletions
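
For orientation, a minimal client-side consumer sketch (not code from this commit) of the new event contract follows. It assumes chunks have already been parsed into {"type": ..., "data": ...} dicts, as the MessageProcessor changes below do, and that the message event's type string is "message"; the "usage", "end_llm_response" and "end_response" strings come from the events handled in the diffs.

# Hedged consumer sketch; `parsed_chunks` is a hypothetical iterable of parsed chunk dicts.
response, usage = "", {}
for chunk in parsed_chunks:
    if chunk["type"] == "message":
        response += chunk["data"]
    elif chunk["type"] == "usage":
        # Shape per get_chat_usage_metrics: {"input_tokens": ..., "output_tokens": ..., "cost": ...}
        usage = chunk["data"]
    elif chunk["type"] == "end_llm_response":
        pass  # LLM output is done, but usage and end_response still follow
    elif chunk["type"] == "end_response":
        break  # only now mark the message as completed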
@@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
                console.log("Started streaming", new Date());
            } else if (chunk.type === 'end_llm_response') {
                console.log("Stopped streaming", new Date());

            } else if (chunk.type === 'end_response') {
                // Automatically respond with voice if the subscribed user has sent voice message
                if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
                    this.textToSpeech(this.chatMessageState.rawResponse);

@@ -133,7 +133,7 @@ export function processMessageChunk(
        console.log(`Started streaming: ${new Date()}`);
    } else if (chunk.type === "end_llm_response") {
        console.log(`Completed streaming: ${new Date()}`);

    } else if (chunk.type === "end_response") {
        // Append any references after all the data has been streamed
        if (codeContext) currentMessage.codeContext = codeContext;
        if (onlineContext) currentMessage.onlineContext = onlineContext;

@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
    get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

logger = logging.getLogger(__name__)

@@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
    aggregated_response = "{" if response_type == "json_object" else ""
    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

    final_message = None
    model_kwargs = model_kwargs or dict()
    if system_prompt:
        model_kwargs["system"] = system_prompt

@@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
    ) as stream:
        for text in stream.text_stream:
            aggregated_response += text
        final_message = stream.get_final_message()

    # Calculate cost of chat
    input_tokens = final_message.usage.input_tokens
    output_tokens = final_message.usage.output_tokens
    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

    # Save conversation trace
    tracer["chat_model"] = model_name

@@ -126,6 +133,7 @@ def anthropic_llm_thread(
    ]

    aggregated_response = ""
    final_message = None
    with client.messages.stream(
        messages=formatted_messages,
        model=model_name,  # type: ignore

@@ -138,6 +146,12 @@ def anthropic_llm_thread(
        for text in stream.text_stream:
            aggregated_response += text
            g.send(text)
        final_message = stream.get_final_message()

    # Calculate cost of chat
    input_tokens = final_message.usage.input_tokens
    output_tokens = final_message.usage.output_tokens
    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

    # Save conversation trace
    tracer["chat_model"] = model_name

@@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
    get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

logger = logging.getLogger(__name__)

@@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
        response = chat_session.send_message(formatted_messages[-1]["parts"])
        response_text = response.text
    except StopCandidateException as e:
        response = None
        response_text, _ = handle_gemini_response(e.args)
        # Respond with reason for stopping
        logger.warning(

@@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
        )

    # Aggregate cost of chat
    input_tokens = response.usage_metadata.prompt_token_count if response else 0
    output_tokens = response.usage_metadata.candidates_token_count if response else 0
    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

    # Save conversation trace
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature

@@ -146,6 +152,11 @@ def gemini_llm_thread(
            if stopped:
                raise StopCandidateException(message)

        # Calculate cost of chat
        input_tokens = chunk.usage_metadata.prompt_token_count
        output_tokens = chunk.usage_metadata.candidates_token_count
        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

        # Save conversation trace
        tracer["chat_model"] = model_name
        tracer["temperature"] = temperature

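For quick reference across the provider changes: the Anthropic path reads token counts from final_message.usage.input_tokens and final_message.usage.output_tokens, the Gemini path from usage_metadata.prompt_token_count and usage_metadata.candidates_token_count, and the OpenAI path (next file) from usage.prompt_tokens and usage.completion_tokens; all three feed the same get_chat_usage_metrics helper.
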
@@ -4,6 +4,8 @@ from threading import Thread
from typing import Dict

import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from tenacity import (
    before_sleep_log,
    retry,

@@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
    commit_conversation_trace,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode

logger = logging.getLogger(__name__)

@@ -64,27 +66,34 @@ def completion_with_backoff(
    if os.getenv("KHOJ_LLM_SEED"):
        model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

    chat = client.chat.completions.create(
        stream=stream,
    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
        messages=formatted_messages,  # type: ignore
        model=model,  # type: ignore
        stream=stream,
        stream_options={"include_usage": True} if stream else {},
        temperature=temperature,
        timeout=20,
        **(model_kwargs or dict()),
    )

    if not stream:
        return chat.choices[0].message.content

    aggregated_response = ""
    for chunk in chat:
        if len(chunk.choices) == 0:
            continue
        delta_chunk = chunk.choices[0].delta  # type: ignore
        if isinstance(delta_chunk, str):
            aggregated_response += delta_chunk
        elif delta_chunk.content:
            aggregated_response += delta_chunk.content
    if not stream:
        chunk = chat
        aggregated_response = chunk.choices[0].message.content
    else:
        for chunk in chat:
            if len(chunk.choices) == 0:
                continue
            delta_chunk = chunk.choices[0].delta  # type: ignore
            if isinstance(delta_chunk, str):
                aggregated_response += delta_chunk
            elif delta_chunk.content:
                aggregated_response += delta_chunk.content

    # Calculate cost of chat
    input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
    output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
    tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))

    # Save conversation trace
    tracer["chat_model"] = model

@@ -164,10 +173,11 @@ def llm_thread(
        if os.getenv("KHOJ_LLM_SEED"):
            model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

        chat = client.chat.completions.create(
            stream=stream,
        chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
            messages=formatted_messages,
            model=model_name,  # type: ignore
            stream=stream,
            stream_options={"include_usage": True} if stream else {},
            temperature=temperature,
            timeout=20,
            **(model_kwargs or dict()),

@@ -175,7 +185,8 @@ def llm_thread(

        aggregated_response = ""
        if not stream:
            aggregated_response = chat.choices[0].message.content
            chunk = chat
            aggregated_response = chunk.choices[0].message.content
            g.send(aggregated_response)
        else:
            for chunk in chat:

@@ -191,6 +202,11 @@ def llm_thread(
                aggregated_response += text_chunk
                g.send(text_chunk)

        # Calculate cost of chat
        input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
        output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

        # Save conversation trace
        tracer["chat_model"] = model_name
        tracer["temperature"] = temperature

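To make the stream_options change above concrete, here is a hedged standalone sketch (assuming openai>=1.26 and OPENAI_API_KEY in the environment; the model and prompt are placeholders, not code from this commit). With {"include_usage": True}, the API sends one extra final chunk whose choices list is empty and whose usage field carries token counts; OpenAI-compatible proxies that ignore the option simply never send it, which is why the hunks guard with hasattr(chunk, "usage") and chunk.usage.

import openai

client = openai.OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

last_chunk = None
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    last_chunk = chunk  # the final chunk has no choices but may carry usage

# Guarded read, mirroring the hunks above: proxies may omit usage entirely
if getattr(last_chunk, "usage", None):
    print(f"\nprompt={last_chunk.usage.prompt_tokens} completion={last_chunk.usage.completion_tokens}")
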
@@ -5,7 +5,6 @@ import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime

@@ -57,7 +56,7 @@ model_to_prompt_size = {
    "gemini-1.5-flash": 20000,
    "gemini-1.5-pro": 20000,
    # Anthropic Models
    "claude-3-5-sonnet-20240620": 20000,
    "claude-3-5-sonnet-20241022": 20000,
    "claude-3-5-haiku-20241022": 20000,
    # Offline Models
    "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,

@@ -213,6 +212,8 @@ class ChatEvent(Enum):
    REFERENCES = "references"
    STATUS = "status"
    METADATA = "metadata"
    USAGE = "usage"
    END_RESPONSE = "end_response"


def message_to_log(

@@ -667,27 +667,37 @@ async def chat(
        finally:
            yield event_delimiter

    async def send_llm_response(response: str):
    async def send_llm_response(response: str, usage: dict = None):
        # Send Chat Response
        async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
            yield result
        async for result in send_event(ChatEvent.MESSAGE, response):
            yield result
        async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
            yield result
        # Send Usage Metadata once llm interactions are complete
        if usage:
            async for event in send_event(ChatEvent.USAGE, usage):
                yield event
        async for result in send_event(ChatEvent.END_RESPONSE, ""):
            yield result

    def collect_telemetry():
        # Gather chat response telemetry
        nonlocal chat_metadata
        latency = time.perf_counter() - start_time
        cmd_set = set([cmd.value for cmd in conversation_commands])
        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
        chat_metadata = chat_metadata or {}
        chat_metadata["conversation_command"] = cmd_set
        chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
        chat_metadata["latency"] = f"{latency:.3f}"
        chat_metadata["ttft_latency"] = f"{ttft:.3f}"
        chat_metadata["usage"] = tracer.get("usage")

        logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
        logger.info(f"Chat response total time: {latency:.3f} seconds")
        logger.info(f"Chat response cost: ${cost:.5f}")
        update_telemetry_state(
            request=request,
            telemetry_type="api",

@@ -699,7 +709,7 @@ async def chat(
    )

    if is_query_empty(q):
        async for result in send_llm_response("Please ask your query to get started."):
        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
            yield result
        return

@@ -713,7 +723,7 @@ async def chat(
        create_new=body.create_new,
    )
    if not conversation:
        async for result in send_llm_response(f"Conversation {conversation_id} not found"):
        async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
            yield result
        return
    conversation_id = conversation.id

@@ -777,7 +787,7 @@ async def chat(
            await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
            q = q.replace(f"/{cmd.value}", "").strip()
        except HTTPException as e:
            async for result in send_llm_response(str(e.detail)):
            async for result in send_llm_response(str(e.detail), tracer.get("usage")):
                yield result
            return

@@ -834,7 +844,7 @@ async def chat(
        agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
        if len(file_filters) == 0 and not agent_has_entries:
            response_log = "No files selected for summarization. Please add files using the section on the left."
            async for result in send_llm_response(response_log):
            async for result in send_llm_response(response_log, tracer.get("usage")):
                yield result
        else:
            async for response in generate_summary_from_files(

@@ -853,7 +863,7 @@ async def chat(
        else:
            if isinstance(response, str):
                response_log = response
                async for result in send_llm_response(response):
                async for result in send_llm_response(response, tracer.get("usage")):
                    yield result

        await sync_to_async(save_to_conversation_log)(

@@ -880,7 +890,7 @@ async def chat(
        conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
        model_type = conversation_config.model_type
        formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
        async for result in send_llm_response(formatted_help):
        async for result in send_llm_response(formatted_help, tracer.get("usage")):
            yield result
        return
    # Adding specification to search online specifically on khoj.dev pages.

@@ -895,7 +905,7 @@ async def chat(
        except Exception as e:
            logger.error(f"Error scheduling task {q} for {user.email}: {e}")
            error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
            async for result in send_llm_response(error_message):
            async for result in send_llm_response(error_message, tracer.get("usage")):
                yield result
            return

@@ -916,7 +926,7 @@ async def chat(
            raw_query_files=raw_query_files,
            tracer=tracer,
        )
        async for result in send_llm_response(llm_response):
        async for result in send_llm_response(llm_response, tracer.get("usage")):
            yield result
        return

@@ -963,7 +973,7 @@ async def chat(
            yield result

    if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
        async for result in send_llm_response(f"{no_entries_found.format()}"):
        async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
            yield result
        return

@@ -1105,7 +1115,7 @@ async def chat(
                "detail": improved_image_prompt,
                "image": None,
            }
            async for result in send_llm_response(json.dumps(content_obj)):
            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                yield result
            return

@@ -1132,7 +1142,7 @@ async def chat(
            "inferredQueries": [improved_image_prompt],
            "image": generated_image,
        }
        async for result in send_llm_response(json.dumps(content_obj)):
        async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
            yield result
        return

@@ -1166,7 +1176,7 @@ async def chat(
            diagram_description = excalidraw_diagram_description
        else:
            error_message = "Failed to generate diagram. Please try again later."
            async for result in send_llm_response(error_message):
            async for result in send_llm_response(error_message, tracer.get("usage")):
                yield result

            await sync_to_async(save_to_conversation_log)(

@@ -1213,7 +1223,7 @@ async def chat(
            tracer=tracer,
        )

        async for result in send_llm_response(json.dumps(content_obj)):
        async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
            yield result
        return

@@ -1252,6 +1262,11 @@ async def chat(
        if item is None:
            async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                yield result
            # Send Usage Metadata once llm interactions are complete
            async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
                yield event
            async for result in send_event(ChatEvent.END_RESPONSE, ""):
                yield result
            logger.debug("Finished streaming response")
            return
        if not connection_alive or not continue_stream:

@@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
class MessageProcessor:
    def __init__(self):
        self.references = {}
        self.usage = {}
        self.raw_response = ""

    def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:

@@ -1793,6 +1794,8 @@ class MessageProcessor:
            chunk_type = ChatEvent(chunk["type"])
            if chunk_type == ChatEvent.REFERENCES:
                self.references = chunk["data"]
            elif chunk_type == ChatEvent.USAGE:
                self.usage = chunk["data"]
            elif chunk_type == ChatEvent.MESSAGE:
                chunk_data = chunk["data"]
                if isinstance(chunk_data, dict):

@@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
    if buffer:
        processor.process_message_chunk(buffer)

    return {"response": processor.raw_response, "references": processor.references}
    return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}


def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):

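A hypothetical caller sketch (not part of this commit) of the extra "usage" key now returned by read_chat_stream; response_iterator is assumed to be the async iterator over the chat API's streamed response, and the snippet would live inside an async function.

result = await read_chat_stream(response_iterator)
usage = result.get("usage", {})
print(result["response"])
print(
    f"tokens in/out: {usage.get('input_tokens', 0)}/{usage.get('output_tokens', 0)}, "
    f"cost: ${usage.get('cost', 0):.5f}"
)
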
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict

app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / "khoj/interface/web/"

@@ -31,3 +32,19 @@ default_config = {
        "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
    },
}

model_to_cost: Dict[str, Dict[str, float]] = {
    # OpenAI Pricing: https://openai.com/api/pricing/
    "gpt-4o": {"input": 2.50, "output": 10.00},
    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
    "o1-preview": {"input": 15.0, "output": 60.00},
    "o1-mini": {"input": 3.0, "output": 12.0},
    # Gemini Pricing: https://ai.google.dev/pricing
    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
    "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
    "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
    "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
    # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
    "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
    "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
}

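As a worked example of the per-million-token pricing above: a gpt-4o call that consumes 1,200 input tokens and produces 350 output tokens costs 1200/1e6 * 2.50 + 350/1e6 * 10.00 = $0.0030 + $0.0035 = $0.0065.
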
@@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
def get_country_name_from_timezone(tz: str) -> str:
    """Get country name from timezone"""
    return country_names.get(get_country_code_from_timezone(tz), "United States")


def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
    """
    Calculate cost of chat message based on input and output tokens
    """
    # Calculate cost of input and output tokens. Costs are per million tokens
    input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
    output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)

    return input_cost + output_cost + prev_cost


def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
    """
    Get usage metrics for chat message based on input and output tokens
    """
    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
    return {
        "input_tokens": prev_usage["input_tokens"] + input_tokens,
        "output_tokens": prev_usage["output_tokens"] + output_tokens,
        "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
    }
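
A hypothetical accumulation example (assuming the khoj.utils.helpers module above is importable and using the gpt-4o-mini rates of $0.15 input / $0.60 output per million tokens from model_to_cost):

from khoj.utils.helpers import get_chat_usage_metrics

# First LLM call of a chat turn
usage = get_chat_usage_metrics("gpt-4o-mini", input_tokens=1000, output_tokens=500)
# -> roughly {"input_tokens": 1000, "output_tokens": 500, "cost": 0.00045}

# Second call in the same turn, threading the previous usage dict through
usage = get_chat_usage_metrics("gpt-4o-mini", input_tokens=2000, output_tokens=1000, usage=usage)
# -> roughly {"input_tokens": 3000, "output_tokens": 1500, "cost": 0.00135}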