Track running costs & accuracy of eval runs in progress
Collect, display and store the running cost & accuracy of the eval run. This provides more insight into eval runs during execution, instead of having to wait until the eval run completes.
This commit is contained in:
parent c53c3db96b
commit ed364fa90e

1 changed file with 85 additions and 24 deletions
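At a glance, the change threads a cost figure through both API calls (the Khoj chat response's usage block and the Gemini evaluator's token counts) and keeps running totals so every processed example can be logged with the accuracy and spend so far. A minimal sketch of that bookkeeping, using made-up per-example numbers purely for illustration (the real script uses the thread-safe counters shown in the diff below):

# Hypothetical per-example records: (decision, query_cost, eval_cost)
examples = [(True, 0.0021, 0.0004), (False, 0.0018, 0.0003), (None, 0.0, 0.0)]

total_cost = 0.0
true_count = false_count = 0
for decision, query_cost, eval_cost in examples:
    total_cost += query_cost + eval_cost
    if decision is not None:  # None means the evaluation itself failed
        true_count += int(decision)
        false_count += int(not decision)
    answered = true_count + false_count
    running_accuracy = true_count / answered if answered else 0.0
    print(f"Accuracy: {running_accuracy:.2%}  Cost: ${total_cost:.5f}")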
@@ -6,13 +6,15 @@ import os
 import time
 from datetime import datetime
 from io import StringIO
+from textwrap import dedent
+from threading import Lock
 from typing import Any, Dict
 
 import pandas as pd
 import requests
 from datasets import Dataset, load_dataset
 
-from khoj.utils.helpers import is_none_or_empty, timer
+from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
 
 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
 SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1 # Sleep between API calls to avoid rate limiting
 
 
+class Counter:
+    """Thread-safe counter for tracking metrics"""
+
+    def __init__(self, value=0.0):
+        self.value = value
+        self.lock = Lock()
+
+    def add(self, amount):
+        with self.lock:
+            self.value += amount
+
+    def get(self):
+        with self.lock:
+            return self.value
+
+
+# Track running metrics while evaluating
+running_cost = Counter()
+running_true_count = Counter(0)
+running_false_count = Counter(0)
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
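process_batch appears to be run concurrently across batches (hence the Lock and the later global running_cost), so the Counter above serializes its updates; plain += on a shared float can lose updates when two threads interleave. A quick standalone check of the class as added above, with a ThreadPoolExecutor used purely for demonstration:

from concurrent.futures import ThreadPoolExecutor
from threading import Lock


class Counter:
    """Thread-safe counter, as added in the hunk above"""

    def __init__(self, value=0.0):
        self.value = value
        self.lock = Lock()

    def add(self, amount):
        with self.lock:
            self.value += amount

    def get(self):
        with self.lock:
            return self.value


cost = Counter()
with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(1000):
        pool.submit(cost.add, 0.001)  # 1000 concurrent $0.001 additions
print(f"${cost.get():.3f}")  # $1.000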
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
         return None
 
 
-def get_agent_response(prompt: str) -> str:
+def get_agent_response(prompt: str) -> Dict[str, Any]:
     """Get response from the Khoj API"""
+    # Set headers
+    headers = {"Content-Type": "application/json"}
+    if not is_none_or_empty(KHOJ_API_KEY):
+        headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
+
     try:
         response = requests.post(
             KHOJ_CHAT_API_URL,
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
+            headers=headers,
             json={
                 "q": prompt,
                 "create_new": True,
             },
         )
         response.raise_for_status()
-        return response.json().get("response", "")
+        response_json = response.json()
+        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return ""
+        return {"response": "", "usage": {}}
 
 
-def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
+def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
     Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
             },
         )
         response.raise_for_status()
+        response_json = response.json()
+
+        # Update cost of evaluation
+        input_tokens = response_json["usageMetadata"]["promptTokenCount"]
+        output_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
+        cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, output_tokens)
 
         # Parse evaluation response
         eval_response: dict[str, str] = json.loads(
-            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
+            clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
        )
         decision = str(eval_response.get("decision", "")).upper() == "TRUE"
         explanation = eval_response.get("explanation", "")
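On the evaluator side, the cost comes from Gemini's usageMetadata token counts priced by khoj's get_cost_of_chat_message helper. The sketch below reproduces the same arithmetic with placeholder per-million-token rates; the model name and prices here are assumptions for illustration, not khoj's actual pricing table:

# Assumed USD prices per million tokens; khoj's get_cost_of_chat_message uses its own pricing table.
PRICE_PER_MTOK = {"gemini-1.5-flash": {"input": 0.075, "output": 0.30}}


def cost_of_gemini_response(model: str, usage_metadata: dict) -> float:
    """Price a Gemini response from its usageMetadata block."""
    prices = PRICE_PER_MTOK[model]
    input_tokens = usage_metadata["promptTokenCount"]
    output_tokens = usage_metadata["candidatesTokenCount"]
    return (input_tokens * prices["input"] + output_tokens * prices["output"]) / 1e6


print(cost_of_gemini_response("gemini-1.5-flash", {"promptTokenCount": 1200, "candidatesTokenCount": 150}))
# 0.000135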
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
         if "503 Service Error" in explanation:
             decision = None
         # Extract decision and explanation from structured response
-        return decision, explanation
+        return decision, explanation, cost
     except Exception as e:
         logger.error(f"Error in evaluation: {e}")
-        return None, f"Evaluation failed: {str(e)}"
+        return None, f"Evaluation failed: {str(e)}", 0.0
 
 
 def process_batch(batch, batch_start, results, dataset_length):
+    global running_cost
     for idx, (prompt, answer, reasoning_type) in enumerate(batch):
         current_index = batch_start + idx
         logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt
 
         # Get agent response
-        agent_response = get_agent_response(prompt)
+        response = get_agent_response(prompt)
+        agent_response = response["response"]
+        agent_usage = response["usage"]
 
         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This may be due to a service error."
         else:
-            decision, explanation = evaluate_response(prompt, agent_response, answer)
+            decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)
 
         # Store results
         results.append(
@@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
                 "evaluation_decision": decision,
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
+                "usage": agent_usage,
             }
         )
 
-        # Log results
+        # Update running cost
+        query_cost = float(agent_usage.get("cost", 0.0))
+        running_cost.add(query_cost + eval_cost)
+
+        # Update running accuracy
+        running_accuracy = 0.0
+        if decision is not None:
+            running_true_count.add(1) if decision == True else running_false_count.add(1)
+            running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
+
+        ## Log results
         decision_color = {True: "green", None: "blue", False: "red"}[decision]
         colored_decision = color_text(str(decision), decision_color)
-        logger.info(
-            f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
-        )
+        result_to_print = f"""
+        ---------
+        Decision: {colored_decision}
+        Accuracy: {running_accuracy:.2%}
+        Question: {prompt}
+        Expected Answer: {answer}
+        Agent Answer: {agent_response}
+        Explanation: {explanation}
+        Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
+        ---------
+        """
+        logger.info(dedent(result_to_print).lstrip())
 
-        time.sleep(SLEEP_SECONDS) # Rate limiting
+        # Sleep between API calls to avoid rate limiting
+        time.sleep(SLEEP_SECONDS)
 
 
 def color_text(text, color):
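The per-result block is written as an indented f-string and logged through dedent().lstrip(), which is why the commit imports dedent from textwrap: dedent strips the common leading whitespace so the logged lines start at column zero, and lstrip drops the leading newline. A tiny check of that behaviour with stand-in values:

from textwrap import dedent

decision, running_accuracy, running_cost = True, 0.5, 0.0046
result_to_print = f"""
    ---------
    Decision: {decision}
    Accuracy: {running_accuracy:.2%}
    Cost: ${running_cost:.5f}
    ---------
    """
print(dedent(result_to_print).lstrip())
# ---------
# Decision: True
# Accuracy: 50.00%
# Cost: $0.00460
# ---------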
@@ -281,17 +341,18 @@ def main():
         lambda x: (x == True).mean()
     )
 
-    # Print summary
+    # Collect summary
     colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
-    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
-    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
-
-    # Save summary to file
+    colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
+    accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
+    accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
+    cost = f"Total Cost: ${running_cost.get():.5f}."
     sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
     sample_type += " Randomized." if RANDOMIZE else ""
-    summary = (
-        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
-    )
+    logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
+
+    # Save summary to file
+    summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
     summary_file = args.output.replace(".csv", ".txt") if args.output else None
     summary_file = (
         summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"