Track, return cost and usage metrics in chat api response

- Track input, output token usage and cost for interactions
  via chat api with openai, anthropic and google chat models

- Get usage metadata from OpenAI using stream_options
- Handle openai proxies that do not support passing usage in response

- Add new usage, end response  events returned by chat api.
  - This can be optionally consumed by clients at a later point
  - Update streaming clients to mark message as completed after new
    end response event, not after end llm response event
- Ensure usage data from final response generation step is included
  - Pass usage data after llm response complete. This allows gathering
    token usage and cost for the final response generation step across
    streaming and non-streaming modes
This commit is contained in:
Debanjum 2024-11-18 18:23:05 -08:00
parent 7bdc9590dd
commit c53c3db96b
10 changed files with 139 additions and 38 deletions

View file

@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
} else if (chunk.type === 'end_response') {
// Automatically respond with voice if the subscribed user has sent voice message
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
this.textToSpeech(this.chatMessageState.rawResponse);

View file

@ -133,7 +133,7 @@ export function processMessageChunk(
console.log(`Started streaming: ${new Date()}`);
} else if (chunk.type === "end_llm_response") {
console.log(`Completed streaming: ${new Date()}`);
} else if (chunk.type === "end_response") {
// Append any references after all the data has been streamed
if (codeContext) currentMessage.codeContext = codeContext;
if (onlineContext) currentMessage.onlineContext = onlineContext;

View file

@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
final_message = None
model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt
@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
) as stream:
for text in stream.text_stream:
aggregated_response += text
final_message = stream.get_final_message()
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
@ -126,6 +133,7 @@ def anthropic_llm_thread(
]
aggregated_response = ""
final_message = None
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@ -138,6 +146,12 @@ def anthropic_llm_thread(
for text in stream.text_stream:
aggregated_response += text
g.send(text)
final_message = stream.get_final_message()
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name

View file

@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
response = chat_session.send_message(formatted_messages[-1]["parts"])
response_text = response.text
except StopCandidateException as e:
response = None
response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
# Aggregate cost of chat
input_tokens = response.usage_metadata.prompt_token_count if response else 0
output_tokens = response.usage_metadata.candidates_token_count if response else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
@ -146,6 +152,11 @@ def gemini_llm_thread(
if stopped:
raise StopCandidateException(message)
# Calculate cost of chat
input_tokens = chunk.usage_metadata.prompt_token_count
output_tokens = chunk.usage_metadata.candidates_token_count
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature

View file

@ -4,6 +4,8 @@ from threading import Thread
from typing import Dict
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from tenacity import (
before_sleep_log,
retry,
@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode
logger = logging.getLogger(__name__)
@ -64,27 +66,34 @@ def completion_with_backoff(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
)
if not stream:
return chat.choices[0].message.content
aggregated_response = ""
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model
@ -164,10 +173,11 @@ def llm_thread(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages,
model=model_name, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
@ -175,7 +185,8 @@ def llm_thread(
aggregated_response = ""
if not stream:
aggregated_response = chat.choices[0].message.content
chunk = chat
aggregated_response = chunk.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
@ -191,6 +202,11 @@ def llm_thread(
aggregated_response += text_chunk
g.send(text_chunk)
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature

View file

@ -5,7 +5,6 @@ import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
@ -57,7 +56,7 @@ model_to_prompt_size = {
"gemini-1.5-flash": 20000,
"gemini-1.5-pro": 20000,
# Anthropic Models
"claude-3-5-sonnet-20240620": 20000,
"claude-3-5-sonnet-20241022": 20000,
"claude-3-5-haiku-20241022": 20000,
# Offline Models
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@ -213,6 +212,8 @@ class ChatEvent(Enum):
REFERENCES = "references"
STATUS = "status"
METADATA = "metadata"
USAGE = "usage"
END_RESPONSE = "end_response"
def message_to_log(

View file

@ -667,27 +667,37 @@ async def chat(
finally:
yield event_delimiter
async def send_llm_response(response: str):
async def send_llm_response(response: str, usage: dict = None):
# Send Chat Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
async for result in send_event(ChatEvent.MESSAGE, response):
yield result
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
if usage:
async for event in send_event(ChatEvent.USAGE, usage):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
def collect_telemetry():
# Gather chat response telemetry
nonlocal chat_metadata
latency = time.perf_counter() - start_time
cmd_set = set([cmd.value for cmd in conversation_commands])
cost = (tracer.get("usage", {}) or {}).get("cost", 0)
chat_metadata = chat_metadata or {}
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
chat_metadata["latency"] = f"{latency:.3f}"
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
chat_metadata["usage"] = tracer.get("usage")
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
logger.info(f"Chat response total time: {latency:.3f} seconds")
logger.info(f"Chat response cost: ${cost:.5f}")
update_telemetry_state(
request=request,
telemetry_type="api",
@ -699,7 +709,7 @@ async def chat(
)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
yield result
return
@ -713,7 +723,7 @@ async def chat(
create_new=body.create_new,
)
if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
yield result
return
conversation_id = conversation.id
@ -777,7 +787,7 @@ async def chat(
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e:
async for result in send_llm_response(str(e.detail)):
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
yield result
return
@ -834,7 +844,7 @@ async def chat(
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
async for result in send_llm_response(response_log, tracer.get("usage")):
yield result
else:
async for response in generate_summary_from_files(
@ -853,7 +863,7 @@ async def chat(
else:
if isinstance(response, str):
response_log = response
async for result in send_llm_response(response):
async for result in send_llm_response(response, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
@ -880,7 +890,7 @@ async def chat(
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
async for result in send_llm_response(formatted_help, tracer.get("usage")):
yield result
return
# Adding specification to search online specifically on khoj.dev pages.
@ -895,7 +905,7 @@ async def chat(
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
async for result in send_llm_response(error_message, tracer.get("usage")):
yield result
return
@ -916,7 +926,7 @@ async def chat(
raw_query_files=raw_query_files,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
async for result in send_llm_response(llm_response, tracer.get("usage")):
yield result
return
@ -963,7 +973,7 @@ async def chat(
yield result
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
yield result
return
@ -1105,7 +1115,7 @@ async def chat(
"detail": improved_image_prompt,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1132,7 +1142,7 @@ async def chat(
"inferredQueries": [improved_image_prompt],
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1166,7 +1176,7 @@ async def chat(
diagram_description = excalidraw_diagram_description
else:
error_message = "Failed to generate diagram. Please try again later."
async for result in send_llm_response(error_message):
async for result in send_llm_response(error_message, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
@ -1213,7 +1223,7 @@ async def chat(
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1252,6 +1262,11 @@ async def chat(
if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:

View file

@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
class MessageProcessor:
def __init__(self):
self.references = {}
self.usage = {}
self.raw_response = ""
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
@ -1793,6 +1794,8 @@ class MessageProcessor:
chunk_type = ChatEvent(chunk["type"])
if chunk_type == ChatEvent.REFERENCES:
self.references = chunk["data"]
elif chunk_type == ChatEvent.USAGE:
self.usage = chunk["data"]
elif chunk_type == ChatEvent.MESSAGE:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
if buffer:
processor.process_message_chunk(buffer)
return {"response": processor.raw_response, "references": processor.references}
return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):

View file

@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict
app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / "khoj/interface/web/"
@ -31,3 +32,19 @@ default_config = {
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
}
model_to_cost: Dict[str, Dict[str, float]] = {
# OpenAI Pricing: https://openai.com/api/pricing/
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"o1-preview": {"input": 15.0, "output": 60.00},
"o1-mini": {"input": 3.0, "output": 12.0},
# Gemini Pricing: https://ai.google.dev/pricing
"gemini-1.5-flash": {"input": 0.075, "output": 0.30},
"gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
}

View file

@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
def get_country_name_from_timezone(tz: str) -> str:
"""Get country name from timezone"""
return country_names.get(get_country_code_from_timezone(tz), "United States")
def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
"""
Calculate cost of chat message based on input and output tokens
"""
# Calculate cost of input and output tokens. Costs are per million tokens
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
return input_cost + output_cost + prev_cost
def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
"""
Get usage metrics for chat message based on input and output tokens
"""
prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
return {
"input_tokens": prev_usage["input_tokens"] + input_tokens,
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
}