Track running costs & accuracy of eval runs in progress

Collect, display and store the running cost & accuracy of an eval run.

This provides insight into an eval run while it is executing, instead of
having to wait until the run completes.
Debanjum 2024-11-19 01:06:02 -08:00
parent c53c3db96b
commit ed364fa90e
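
The change boils down to a handful of thread-safe counters that each worker updates per example, from which a running accuracy and running cost are derived. Below is a minimal, self-contained sketch of that bookkeeping pattern; the decisions and dollar amounts in `examples` are made up purely for illustration, and the real logic is in the diff that follows.

from threading import Lock

class Counter:
    """Thread-safe accumulator, safe to share across worker threads."""

    def __init__(self, value=0.0):
        self.value = value
        self.lock = Lock()

    def add(self, amount):
        with self.lock:
            self.value += amount

    def get(self):
        with self.lock:
            return self.value

running_cost = Counter()
running_true_count = Counter()
running_false_count = Counter()

# Hypothetical per-example results: (decision, query cost, eval cost)
examples = [(True, 0.0012, 0.0003), (False, 0.0010, 0.0004), (True, 0.0015, 0.0003)]

for decision, query_cost, eval_cost in examples:
    running_cost.add(query_cost + eval_cost)
    (running_true_count if decision else running_false_count).add(1)
    accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
    print(f"Accuracy: {accuracy:.2%} | Cost: ${running_cost.get():.5f}")
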

@@ -6,13 +6,15 @@ import os
import time
from datetime import datetime
from io import StringIO
from textwrap import dedent
from threading import Lock
from typing import Any, Dict
import pandas as pd
import requests
from datasets import Dataset, load_dataset
from khoj.utils.helpers import is_none_or_empty, timer
from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
# Configure root logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1 # Sleep between API calls to avoid rate limiting
class Counter:
"""Thread-safe counter for tracking metrics"""
def __init__(self, value=0.0):
self.value = value
self.lock = Lock()
def add(self, amount):
with self.lock:
self.value += amount
def get(self):
with self.lock:
return self.value
# Track running metrics while evaluating
running_cost = Counter()
running_true_count = Counter(0)
running_false_count = Counter(0)
def load_frames_dataset():
"""
Load the Google FRAMES benchmark dataset from HuggingFace
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
return None
def get_agent_response(prompt: str) -> str:
def get_agent_response(prompt: str) -> Dict[str, Any]:
"""Get response from the Khoj API"""
# Set headers
headers = {"Content-Type": "application/json"}
if not is_none_or_empty(KHOJ_API_KEY):
headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
try:
response = requests.post(
KHOJ_CHAT_API_URL,
headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
headers=headers,
json={
"q": prompt,
"create_new": True,
},
)
response.raise_for_status()
return response.json().get("response", "")
response_json = response.json()
return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
except Exception as e:
logger.error(f"Error getting agent response: {e}")
return ""
return {"response": "", "usage": {}}
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
"""Evaluate Khoj response against benchmark ground truth using Gemini"""
evaluation_prompt = f"""
Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
},
)
response.raise_for_status()
response_json = response.json()
# Update cost of evaluation
input_tokens = response_json["usageMetadata"]["promptTokenCount"]
output_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, output_tokens)
# Parse evaluation response
eval_response: dict[str, str] = json.loads(
clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
)
decision = str(eval_response.get("decision", "")).upper() == "TRUE"
explanation = eval_response.get("explanation", "")
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
if "503 Service Error" in explanation:
decision = None
# Extract decision and explanation from structured response
return decision, explanation
return decision, explanation, cost
except Exception as e:
logger.error(f"Error in evaluation: {e}")
return None, f"Evaluation failed: {str(e)}"
return None, f"Evaluation failed: {str(e)}", 0.0
def process_batch(batch, batch_start, results, dataset_length):
global running_cost
for idx, (prompt, answer, reasoning_type) in enumerate(batch):
current_index = batch_start + idx
logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt
# Get agent response
agent_response = get_agent_response(prompt)
response = get_agent_response(prompt)
agent_response = response["response"]
agent_usage = response["usage"]
# Evaluate response
if is_none_or_empty(agent_response):
decision = None
explanation = "Agent response is empty. This may be due to a service error."
eval_cost = 0.0
else:
decision, explanation = evaluate_response(prompt, agent_response, answer)
decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
# Store results
results.append(
@@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
"evaluation_decision": decision,
"evaluation_explanation": explanation,
"reasoning_type": reasoning_type,
"usage": agent_usage,
}
)
# Log results
# Update running cost
query_cost = float(agent_usage.get("cost", 0.0))
running_cost.add(query_cost + eval_cost)
# Update running accuracy
running_accuracy = 0.0
if decision is not None:
running_true_count.add(1) if decision == True else running_false_count.add(1)
running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
# Log results
decision_color = {True: "green", None: "blue", False: "red"}[decision]
colored_decision = color_text(str(decision), decision_color)
logger.info(
f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
)
result_to_print = f"""
---------
Decision: {colored_decision}
Accuracy: {running_accuracy:.2%}
Question: {prompt}
Expected Answer: {answer}
Agent Answer: {agent_response}
Explanation: {explanation}
Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
---------
"""
logger.info(dedent(result_to_print).lstrip())
time.sleep(SLEEP_SECONDS) # Rate limiting
# Sleep between API calls to avoid rate limiting
time.sleep(SLEEP_SECONDS)
def color_text(text, color):
@@ -281,17 +341,18 @@ def main():
lambda x: (x == True).mean()
)
# Print summary
# Collect summary
colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
logger.info(f"\nOverall Accuracy: {colored_accuracy}")
logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
# Save summary to file
colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
cost = f"Total Cost: ${running_cost.get():.5f}."
sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
sample_type += " Randomized." if RANDOMIZE else ""
summary = (
f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
)
logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
# Save summary to file
summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
summary_file = args.output.replace(".csv", ".txt") if args.output else None
summary_file = (
summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"