Track running costs & accuracy of eval runs in progress

Collect, display and store the running cost & accuracy of an eval run.

This provides more insight into eval runs during execution instead of
having to wait until the eval run completes.
Debanjum 2024-11-19 01:06:02 -08:00
parent c53c3db96b
commit ed364fa90e
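
The change below hinges on a small thread-safe Counter that worker threads share to accumulate cost and true/false decisions while batches are evaluated concurrently. Here is a minimal, self-contained sketch of that pattern, with names mirroring the diff; the record_result helper and the simulated workers are illustrative additions, not part of the actual file:

from threading import Lock, Thread


class Counter:
    """Thread-safe counter for tracking a running metric."""

    def __init__(self, value=0.0):
        self.value = value
        self.lock = Lock()

    def add(self, amount):
        # Serialize updates so concurrent workers don't drop increments
        with self.lock:
            self.value += amount

    def get(self):
        with self.lock:
            return self.value


running_cost = Counter()
running_true_count = Counter(0)
running_false_count = Counter(0)


def record_result(decision, query_cost, eval_cost):
    """Fold one evaluated example into the running metrics (illustrative helper)."""
    running_cost.add(query_cost + eval_cost)
    if decision is not None:
        (running_true_count if decision else running_false_count).add(1)


if __name__ == "__main__":
    # Simulate a few concurrent workers reporting evaluated examples
    workers = [Thread(target=record_result, args=(i % 3 != 0, 0.002, 0.0005)) for i in range(9)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    evaluated = running_true_count.get() + running_false_count.get()
    print(f"Accuracy: {running_true_count.get() / evaluated:.2%}, Cost: ${running_cost.get():.5f}")

The Lock matters because batches are processed concurrently, so a plain += on a shared float could interleave and lose updates.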


@@ -6,13 +6,15 @@ import os
 import time
 from datetime import datetime
 from io import StringIO
+from textwrap import dedent
+from threading import Lock
 from typing import Any, Dict

 import pandas as pd
 import requests
 from datasets import Dataset, load_dataset

-from khoj.utils.helpers import is_none_or_empty, timer
+from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer

 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
 SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1  # Sleep between API calls to avoid rate limiting


+class Counter:
+    """Thread-safe counter for tracking metrics"""
+
+    def __init__(self, value=0.0):
+        self.value = value
+        self.lock = Lock()
+
+    def add(self, amount):
+        with self.lock:
+            self.value += amount
+
+    def get(self):
+        with self.lock:
+            return self.value
+
+
+# Track running metrics while evaluating
+running_cost = Counter()
+running_true_count = Counter(0)
+running_false_count = Counter(0)
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
         return None


-def get_agent_response(prompt: str) -> str:
+def get_agent_response(prompt: str) -> Dict[str, Any]:
     """Get response from the Khoj API"""
+    # Set headers
+    headers = {"Content-Type": "application/json"}
+    if not is_none_or_empty(KHOJ_API_KEY):
+        headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
+
     try:
         response = requests.post(
             KHOJ_CHAT_API_URL,
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
+            headers=headers,
             json={
                 "q": prompt,
                 "create_new": True,
             },
         )
         response.raise_for_status()
-        return response.json().get("response", "")
+        response_json = response.json()
+        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return ""
+        return {"response": "", "usage": {}}


-def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
+def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
     Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
             },
         )
         response.raise_for_status()
+        response_json = response.json()
+
+        # Update cost of evaluation
+        input_tokens = response_json["usageMetadata"]["promptTokenCount"]
+        ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
+        cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)

         # Parse evaluation response
         eval_response: dict[str, str] = json.loads(
-            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
+            clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
         )

         decision = str(eval_response.get("decision", "")).upper() == "TRUE"
         explanation = eval_response.get("explanation", "")
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
         if "503 Service Error" in explanation:
             decision = None

         # Extract decision and explanation from structured response
-        return decision, explanation
+        return decision, explanation, cost
     except Exception as e:
         logger.error(f"Error in evaluation: {e}")
-        return None, f"Evaluation failed: {str(e)}"
+        return None, f"Evaluation failed: {str(e)}", 0.0


 def process_batch(batch, batch_start, results, dataset_length):
+    global running_cost
     for idx, (prompt, answer, reasoning_type) in enumerate(batch):
         current_index = batch_start + idx
         logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt

         # Get agent response
-        agent_response = get_agent_response(prompt)
+        response = get_agent_response(prompt)
+        agent_response = response["response"]
+        agent_usage = response["usage"]

         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This maybe due to a service error."
         else:
-            decision, explanation = evaluate_response(prompt, agent_response, answer)
+            decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)

         # Store results
         results.append(
@@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
                 "evaluation_decision": decision,
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
+                "usage": agent_usage,
             }
         )

-        # Log results
+        # Update running cost
+        query_cost = float(agent_usage.get("cost", 0.0))
+        running_cost.add(query_cost + eval_cost)
+
+        # Update running accuracy
+        running_accuracy = 0.0
+        if decision is not None:
+            running_true_count.add(1) if decision == True else running_false_count.add(1)
+            running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
+
+        ## Log results
         decision_color = {True: "green", None: "blue", False: "red"}[decision]
         colored_decision = color_text(str(decision), decision_color)
-        logger.info(
-            f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
-        )
+        result_to_print = f"""
+        ---------
+        Decision: {colored_decision}
+        Accuracy: {running_accuracy:.2%}
+        Question: {prompt}
+        Expected Answer: {answer}
+        Agent Answer: {agent_response}
+        Explanation: {explanation}
+        Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
+        ---------
+        """
+        logger.info(dedent(result_to_print).lstrip())

-        time.sleep(SLEEP_SECONDS)  # Rate limiting
+        # Sleep between API calls to avoid rate limiting
+        time.sleep(SLEEP_SECONDS)


 def color_text(text, color):
@@ -281,17 +341,18 @@ def main():
         lambda x: (x == True).mean()
     )

-    # Print summary
+    # Collect summary
     colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
-    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
-    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
-
-    # Save summary to file
+    colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
+    accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
+    accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
+    cost = f"Total Cost: ${running_cost.get():.5f}."
     sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
     sample_type += " Randomized." if RANDOMIZE else ""
-    summary = (
-        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
-    )
+    logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
+
+    # Save summary to file
+    summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
     summary_file = args.output.replace(".csv", ".txt") if args.output else None
     summary_file = (
         summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
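
For context on the evaluation-cost lines in the diff: the Gemini response's usageMetadata supplies prompt and candidate token counts, which get_cost_of_chat_message converts into a dollar figure. A rough sketch of what such a helper presumably computes follows; estimate_chat_cost and the per-million-token price table are placeholders, not the actual khoj.utils.helpers implementation or real Gemini pricing:

# Placeholder price table, in dollars per million tokens; values are illustrative, not real rates.
MODEL_PRICES_PER_MILLION_TOKENS = {
    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
}


def estimate_chat_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the dollar cost of one chat call from its token usage."""
    prices = MODEL_PRICES_PER_MILLION_TOKENS.get(model_name, {"input": 0.0, "output": 0.0})
    return (input_tokens * prices["input"] + output_tokens * prices["output"]) / 1e6


# Example: price a single evaluation call from Gemini-style usage metadata
usage_metadata = {"promptTokenCount": 1200, "candidatesTokenCount": 150}
eval_cost = estimate_chat_cost(
    "gemini-1.5-flash",
    usage_metadata["promptTokenCount"],
    usage_metadata["candidatesTokenCount"],
)
print(f"Eval cost: ${eval_cost:.5f}")  # the query cost from the Khoj API's usage field is added on top in the script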