Track Usage Metrics in Chat API. Track Running Cost, Accuracy in Evals (#985)

- Track, return cost and usage metrics in chat api response
  Track input, output token usage and cost of interactions with 
  openai, anthropic and google chat models for each call to the khoj chat api
- Collect, display and store costs & accuracy of eval run currently in progress
  This provides more insight into eval runs during execution 
  instead of having to wait until the eval run completes.
This commit is contained in:
Debanjum 2024-11-20 12:59:44 -08:00 committed by GitHub
commit 6f1adcfe67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 230 additions and 67 deletions

View file

@ -1,9 +1,10 @@
name: Run Khoj Evals
name: eval
on:
# Run on every releases
release:
types: [published]
# Run on every release
push:
tags:
- "*"
# Allow manual triggers from GitHub UI
workflow_dispatch:
inputs:
@ -82,7 +83,7 @@ jobs:
sed -i 's/dynamic = \["version"\]/version = "${{ steps.hatch.outputs.version }}"/' pyproject.toml
pip install --upgrade .[dev]
- name: 📝 Run Evals
- name: 📝 Run Eval
env:
KHOJ_MODE: ${{ matrix.khoj_mode }}
SAMPLE_SIZE: ${{ inputs.sample_size }}

View file

@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
console.log("Started streaming", new Date());
} else if (chunk.type === 'end_llm_response') {
console.log("Stopped streaming", new Date());
} else if (chunk.type === 'end_response') {
// Automatically respond with voice if the subscribed user has sent voice message
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
this.textToSpeech(this.chatMessageState.rawResponse);

View file

@ -133,7 +133,7 @@ export function processMessageChunk(
console.log(`Started streaming: ${new Date()}`);
} else if (chunk.type === "end_llm_response") {
console.log(`Completed streaming: ${new Date()}`);
} else if (chunk.type === "end_response") {
// Append any references after all the data has been streamed
if (codeContext) currentMessage.codeContext = codeContext;
if (onlineContext) currentMessage.onlineContext = onlineContext;

View file

@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
final_message = None
model_kwargs = model_kwargs or dict()
if system_prompt:
model_kwargs["system"] = system_prompt
@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
) as stream:
for text in stream.text_stream:
aggregated_response += text
final_message = stream.get_final_message()
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
@ -126,6 +133,7 @@ def anthropic_llm_thread(
]
aggregated_response = ""
final_message = None
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@ -138,6 +146,12 @@ def anthropic_llm_thread(
for text in stream.text_stream:
aggregated_response += text
g.send(text)
final_message = stream.get_final_message()
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name

View file

@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
response = chat_session.send_message(formatted_messages[-1]["parts"])
response_text = response.text
except StopCandidateException as e:
response = None
response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
# Aggregate cost of chat
input_tokens = response.usage_metadata.prompt_token_count if response else 0
output_tokens = response.usage_metadata.candidates_token_count if response else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
@ -146,6 +152,11 @@ def gemini_llm_thread(
if stopped:
raise StopCandidateException(message)
# Calculate cost of chat
input_tokens = chunk.usage_metadata.prompt_token_count
output_tokens = chunk.usage_metadata.candidates_token_count
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature

View file

@ -4,6 +4,8 @@ from threading import Thread
from typing import Dict
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from tenacity import (
before_sleep_log,
retry,
@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode
logger = logging.getLogger(__name__)
@ -63,27 +65,34 @@ def completion_with_backoff(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
)
if not stream:
return chat.choices[0].message.content
aggregated_response = ""
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
if not stream:
chunk = chat
aggregated_response = chunk.choices[0].message.content
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta # type: ignore
if isinstance(delta_chunk, str):
aggregated_response += delta_chunk
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model
@ -162,10 +171,11 @@ def llm_thread(
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
messages=formatted_messages,
model=model_name, # type: ignore
stream=stream,
stream_options={"include_usage": True} if stream else {},
temperature=temperature,
timeout=20,
**(model_kwargs or dict()),
@ -173,7 +183,8 @@ def llm_thread(
aggregated_response = ""
if not stream:
aggregated_response = chat.choices[0].message.content
chunk = chat
aggregated_response = chunk.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
@ -189,6 +200,11 @@ def llm_thread(
aggregated_response += text_chunk
g.send(text_chunk)
# Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature

View file

@ -5,7 +5,6 @@ import math
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
@ -57,7 +56,7 @@ model_to_prompt_size = {
"gemini-1.5-flash": 20000,
"gemini-1.5-pro": 20000,
# Anthropic Models
"claude-3-5-sonnet-20240620": 20000,
"claude-3-5-sonnet-20241022": 20000,
"claude-3-5-haiku-20241022": 20000,
# Offline Models
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@ -213,6 +212,8 @@ class ChatEvent(Enum):
REFERENCES = "references"
STATUS = "status"
METADATA = "metadata"
USAGE = "usage"
END_RESPONSE = "end_response"
def message_to_log(

View file

@ -667,27 +667,37 @@ async def chat(
finally:
yield event_delimiter
async def send_llm_response(response: str):
async def send_llm_response(response: str, usage: dict = None):
# Send Chat Response
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
yield result
async for result in send_event(ChatEvent.MESSAGE, response):
yield result
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
if usage:
async for event in send_event(ChatEvent.USAGE, usage):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
def collect_telemetry():
# Gather chat response telemetry
nonlocal chat_metadata
latency = time.perf_counter() - start_time
cmd_set = set([cmd.value for cmd in conversation_commands])
cost = (tracer.get("usage", {}) or {}).get("cost", 0)
chat_metadata = chat_metadata or {}
chat_metadata["conversation_command"] = cmd_set
chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
chat_metadata["latency"] = f"{latency:.3f}"
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
chat_metadata["usage"] = tracer.get("usage")
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
logger.info(f"Chat response total time: {latency:.3f} seconds")
logger.info(f"Chat response cost: ${cost:.5f}")
update_telemetry_state(
request=request,
telemetry_type="api",
@ -699,7 +709,7 @@ async def chat(
)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
yield result
return
@ -713,7 +723,7 @@ async def chat(
create_new=body.create_new,
)
if not conversation:
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
yield result
return
conversation_id = conversation.id
@ -777,7 +787,7 @@ async def chat(
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
except HTTPException as e:
async for result in send_llm_response(str(e.detail)):
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
yield result
return
@ -834,7 +844,7 @@ async def chat(
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
if len(file_filters) == 0 and not agent_has_entries:
response_log = "No files selected for summarization. Please add files using the section on the left."
async for result in send_llm_response(response_log):
async for result in send_llm_response(response_log, tracer.get("usage")):
yield result
else:
async for response in generate_summary_from_files(
@ -853,7 +863,7 @@ async def chat(
else:
if isinstance(response, str):
response_log = response
async for result in send_llm_response(response):
async for result in send_llm_response(response, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
@ -880,7 +890,7 @@ async def chat(
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
async for result in send_llm_response(formatted_help, tracer.get("usage")):
yield result
return
# Adding specification to search online specifically on khoj.dev pages.
@ -895,7 +905,7 @@ async def chat(
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
async for result in send_llm_response(error_message):
async for result in send_llm_response(error_message, tracer.get("usage")):
yield result
return
@ -916,7 +926,7 @@ async def chat(
raw_query_files=raw_query_files,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
async for result in send_llm_response(llm_response, tracer.get("usage")):
yield result
return
@ -963,7 +973,7 @@ async def chat(
yield result
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
yield result
return
@ -1105,7 +1115,7 @@ async def chat(
"detail": improved_image_prompt,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1132,7 +1142,7 @@ async def chat(
"inferredQueries": [improved_image_prompt],
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1166,7 +1176,7 @@ async def chat(
diagram_description = excalidraw_diagram_description
else:
error_message = "Failed to generate diagram. Please try again later."
async for result in send_llm_response(error_message):
async for result in send_llm_response(error_message, tracer.get("usage")):
yield result
await sync_to_async(save_to_conversation_log)(
@ -1213,7 +1223,7 @@ async def chat(
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
yield result
return
@ -1252,6 +1262,11 @@ async def chat(
if item is None:
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
yield result
# Send Usage Metadata once llm interactions are complete
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
yield event
async for result in send_event(ChatEvent.END_RESPONSE, ""):
yield result
logger.debug("Finished streaming response")
return
if not connection_alive or not continue_stream:

View file

@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
class MessageProcessor:
def __init__(self):
self.references = {}
self.usage = {}
self.raw_response = ""
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
@ -1793,6 +1794,8 @@ class MessageProcessor:
chunk_type = ChatEvent(chunk["type"])
if chunk_type == ChatEvent.REFERENCES:
self.references = chunk["data"]
elif chunk_type == ChatEvent.USAGE:
self.usage = chunk["data"]
elif chunk_type == ChatEvent.MESSAGE:
chunk_data = chunk["data"]
if isinstance(chunk_data, dict):
@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
if buffer:
processor.process_message_chunk(buffer)
return {"response": processor.raw_response, "references": processor.references}
return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):

View file

@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict
app_root_directory = Path(__file__).parent.parent.parent
web_directory = app_root_directory / "khoj/interface/web/"
@ -31,3 +32,19 @@ default_config = {
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
}
model_to_cost: Dict[str, Dict[str, float]] = {
# OpenAI Pricing: https://openai.com/api/pricing/
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"o1-preview": {"input": 15.0, "output": 60.00},
"o1-mini": {"input": 3.0, "output": 12.0},
# Gemini Pricing: https://ai.google.dev/pricing
"gemini-1.5-flash": {"input": 0.075, "output": 0.30},
"gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
}

View file

@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
def get_country_name_from_timezone(tz: str) -> str:
"""Get country name from timezone"""
return country_names.get(get_country_code_from_timezone(tz), "United States")
def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
"""
Calculate cost of chat message based on input and output tokens
"""
# Calculate cost of input and output tokens. Costs are per million tokens
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
return input_cost + output_cost + prev_cost
def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
"""
Get usage metrics for chat message based on input and output tokens
"""
prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
return {
"input_tokens": prev_usage["input_tokens"] + input_tokens,
"output_tokens": prev_usage["output_tokens"] + output_tokens,
"cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
}

View file

@ -6,13 +6,15 @@ import os
import time
from datetime import datetime
from io import StringIO
from textwrap import dedent
from threading import Lock
from typing import Any, Dict
import pandas as pd
import requests
from datasets import Dataset, load_dataset
from khoj.utils.helpers import is_none_or_empty, timer
from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
# Configure root logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
@ -38,6 +40,28 @@ BATCH_SIZE = int(
SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1 # Sleep between API calls to avoid rate limiting
class Counter:
"""Thread-safe counter for tracking metrics"""
def __init__(self, value=0.0):
self.value = value
self.lock = Lock()
def add(self, amount):
with self.lock:
self.value += amount
def get(self):
with self.lock:
return self.value
# Track running metrics while evaluating
running_cost = Counter()
running_true_count = Counter(0)
running_false_count = Counter(0)
def load_frames_dataset():
"""
Load the Google FRAMES benchmark dataset from HuggingFace
@ -104,25 +128,31 @@ def load_simpleqa_dataset():
return None
def get_agent_response(prompt: str) -> str:
def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API"""
# Set headers
headers = {"Content-Type": "application/json"}
if not is_none_or_empty(KHOJ_API_KEY):
headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
try:
response = requests.post(
KHOJ_CHAT_API_URL,
headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
headers=headers,
json={
"q": prompt,
"create_new": True,
},
)
response.raise_for_status()
return response.json().get("response", "")
response_json = response.json()
return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
except Exception as e:
logger.error(f"Error getting agent response: {e}")
return ""
return {"response": "", "usage": {}}
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
evaluation_prompt = f"""
Compare the following agent response with the ground truth answer.
@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
},
)
response.raise_for_status()
response_json = response.json()
# Update cost of evaluation
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)
# Parse evaluation response
eval_response: dict[str, str] = json.loads(
clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
)
decision = str(eval_response.get("decision", "")).upper() == "TRUE"
explanation = eval_response.get("explanation", "")
@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
if "503 Service Error" in explanation:
decision = None
# Extract decision and explanation from structured response
return decision, explanation
return decision, explanation, cost
except Exception as e:
logger.error(f"Error in evaluation: {e}")
return None, f"Evaluation failed: {str(e)}"
return None, f"Evaluation failed: {str(e)}", 0.0
def process_batch(batch, batch_start, results, dataset_length):
global running_cost
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
current_index = batch_start + idx
logger.info(f"Processing example: {current_index}/{dataset_length}")
@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt
# Get agent response
agent_response = get_agent_response(prompt)
response = get_agent_response(prompt)
agent_response = response["response"]
agent_usage = response["usage"]
# Evaluate response
if is_none_or_empty(agent_response):
decision = None
explanation = "Agent response is empty. This maybe due to a service error."
else:
decision, explanation = evaluate_response(prompt, agent_response, answer)
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
# Store results
results.append(
@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
"evaluation_decision": decision,
"evaluation_explanation": explanation,
"reasoning_type": reasoning_type,
"usage": agent_usage,
}
)
# Log results
# Update running cost
query_cost = float(agent_usage.get("cost", 0.0))
running_cost.add(query_cost + eval_cost)
# Update running accuracy
running_accuracy = 0.0
if decision is not None:
running_true_count.add(1) if decision == True else running_false_count.add(1)
running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
## Log results
decision_color = {True: "green", None: "blue", False: "red"}[decision]
colored_decision = color_text(str(decision), decision_color)
logger.info(
f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
)
result_to_print = f"""
---------
Decision: {colored_decision}
Accuracy: {running_accuracy:.2%}
Question: {prompt}
Expected Answer: {answer}
Agent Answer: {agent_response}
Explanation: {explanation}
Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
---------
"""
logger.info(dedent(result_to_print).lstrip())
time.sleep(SLEEP_SECONDS) # Rate limiting
# Sleep between API calls to avoid rate limiting
time.sleep(SLEEP_SECONDS)
def color_text(text, color):
@ -281,17 +341,18 @@ def main():
lambda x: (x == True).mean()
)
# Print summary
# Collect summary
colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
logger.info(f"\nOverall Accuracy: {colored_accuracy}")
logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
# Save summary to file
colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
cost = f"Total Cost: ${running_cost.get():.5f}."
sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
sample_type += " Randomized." if RANDOMIZE else ""
summary = (
f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
)
logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
# Save summary to file
summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
summary_file = args.output.replace(".csv", ".txt") if args.output else None
summary_file = (
summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"