mirror of https://github.com/khoj-ai/khoj.git
Track running costs & accuracy of eval runs in progress
Collect, display, and store the running cost and accuracy of an eval run. This provides more insight into eval runs during execution, instead of having to wait until the eval run completes.
parent c53c3db96b
commit ed364fa90e

1 changed file with 85 additions and 24 deletions
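For context, the core pattern this commit introduces is a small thread-safe Counter that the batch workers update after every example, so cost and accuracy can be reported while the run is still in progress. Below is a minimal, self-contained sketch of that pattern (not part of the commit): Counter and the running_* globals mirror the diff that follows, while evaluate_example, fake_results, and the thread pool are hypothetical stand-ins for the real Khoj and Gemini calls made by process_batch.

# Sketch of the running-metrics pattern added by this commit.
# Counter and the running_* names mirror the diff below; evaluate_example
# and fake_results are hypothetical placeholders for the real API calls.
from concurrent.futures import ThreadPoolExecutor
from threading import Lock


class Counter:
    """Thread-safe counter for tracking metrics."""

    def __init__(self, value=0.0):
        self.value = value
        self.lock = Lock()

    def add(self, amount):
        with self.lock:
            self.value += amount

    def get(self):
        with self.lock:
            return self.value


running_cost = Counter()
running_true_count = Counter(0)
running_false_count = Counter(0)


def evaluate_example(decision: bool, cost: float) -> None:
    # Each worker records its query + eval cost and the eval decision.
    running_cost.add(cost)
    if decision:
        running_true_count.add(1)
    else:
        running_false_count.add(1)


if __name__ == "__main__":
    # Simulate eval results arriving from parallel batch workers.
    fake_results = [(True, 0.0012), (False, 0.0009), (True, 0.0011), (True, 0.0010)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        for decision, cost in fake_results:
            pool.submit(evaluate_example, decision, cost)

    total = running_true_count.get() + running_false_count.get()
    accuracy = running_true_count.get() / total if total else 0.0
    print(f"Accuracy: {accuracy:.2%} | Cost: ${running_cost.get():.5f}")

A plain float updated with += from several threads is not guaranteed to stay consistent, hence the Lock; in the diff below, process_batch updates these counters after each example and logs the running totals.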
@@ -6,13 +6,15 @@ import os
 import time
 from datetime import datetime
 from io import StringIO
+from textwrap import dedent
+from threading import Lock
 from typing import Any, Dict

 import pandas as pd
 import requests
 from datasets import Dataset, load_dataset

-from khoj.utils.helpers import is_none_or_empty, timer
+from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer

 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
 SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1  # Sleep between API calls to avoid rate limiting


+class Counter:
+    """Thread-safe counter for tracking metrics"""
+
+    def __init__(self, value=0.0):
+        self.value = value
+        self.lock = Lock()
+
+    def add(self, amount):
+        with self.lock:
+            self.value += amount
+
+    def get(self):
+        with self.lock:
+            return self.value
+
+
+# Track running metrics while evaluating
+running_cost = Counter()
+running_true_count = Counter(0)
+running_false_count = Counter(0)
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
         return None


-def get_agent_response(prompt: str) -> str:
+def get_agent_response(prompt: str) -> Dict[str, Any]:
     """Get response from the Khoj API"""
+    # Set headers
+    headers = {"Content-Type": "application/json"}
+    if not is_none_or_empty(KHOJ_API_KEY):
+        headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
+
     try:
         response = requests.post(
             KHOJ_CHAT_API_URL,
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
+            headers=headers,
             json={
                 "q": prompt,
                 "create_new": True,
             },
         )
         response.raise_for_status()
-        return response.json().get("response", "")
+        response_json = response.json()
+        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return ""
+        return {"response": "", "usage": {}}


-def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
+def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
     Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
             },
         )
         response.raise_for_status()
+        response_json = response.json()

+        # Update cost of evaluation
+        input_tokens = response_json["usageMetadata"]["promptTokenCount"]
+        output_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
+        cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, output_tokens)
+
         # Parse evaluation response
         eval_response: dict[str, str] = json.loads(
-            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
+            clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
         )
         decision = str(eval_response.get("decision", "")).upper() == "TRUE"
         explanation = eval_response.get("explanation", "")
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
         if "503 Service Error" in explanation:
             decision = None
         # Extract decision and explanation from structured response
-        return decision, explanation
+        return decision, explanation, cost
     except Exception as e:
         logger.error(f"Error in evaluation: {e}")
-        return None, f"Evaluation failed: {str(e)}"
+        return None, f"Evaluation failed: {str(e)}", 0.0


 def process_batch(batch, batch_start, results, dataset_length):
+    global running_cost
     for idx, (prompt, answer, reasoning_type) in enumerate(batch):
         current_index = batch_start + idx
         logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt

         # Get agent response
-        agent_response = get_agent_response(prompt)
+        response = get_agent_response(prompt)
+        agent_response = response["response"]
+        agent_usage = response["usage"]

         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This may be due to a service error."
         else:
-            decision, explanation = evaluate_response(prompt, agent_response, answer)
+            decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)

         # Store results
         results.append(
@@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
                 "evaluation_decision": decision,
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
+                "usage": agent_usage,
             }
         )

-        # Log results
+        # Update running cost
+        query_cost = float(agent_usage.get("cost", 0.0))
+        running_cost.add(query_cost + eval_cost)
+
+        # Update running accuracy
+        running_accuracy = 0.0
+        if decision is not None:
+            running_true_count.add(1) if decision == True else running_false_count.add(1)
+            running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
+
+        ## Log results
         decision_color = {True: "green", None: "blue", False: "red"}[decision]
         colored_decision = color_text(str(decision), decision_color)
-        logger.info(
-            f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
-        )
+        result_to_print = f"""
+        ---------
+        Decision: {colored_decision}
+        Accuracy: {running_accuracy:.2%}
+        Question: {prompt}
+        Expected Answer: {answer}
+        Agent Answer: {agent_response}
+        Explanation: {explanation}
+        Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
+        ---------
+        """
+        logger.info(dedent(result_to_print).lstrip())

-        time.sleep(SLEEP_SECONDS)  # Rate limiting
+        # Sleep between API calls to avoid rate limiting
+        time.sleep(SLEEP_SECONDS)


 def color_text(text, color):
@@ -281,17 +341,18 @@ def main():
         lambda x: (x == True).mean()
     )

-    # Print summary
+    # Collect summary
     colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
-    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
-    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
-
-    # Save summary to file
+    colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
+    accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
+    accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
+    cost = f"Total Cost: ${running_cost.get():.5f}."
     sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
     sample_type += " Randomized." if RANDOMIZE else ""
-    summary = (
-        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
-    )
+    logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
+
+    # Save summary to file
+    summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
     summary_file = args.output.replace(".csv", ".txt") if args.output else None
     summary_file = (
         summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"