diff --git a/tests/evals/eval.py b/tests/evals/eval.py
index 9c018b76..30c842b2 100644
--- a/tests/evals/eval.py
+++ b/tests/evals/eval.py
@@ -6,13 +6,15 @@ import os
 import time
 from datetime import datetime
 from io import StringIO
+from textwrap import dedent
+from threading import Lock
 from typing import Any, Dict

 import pandas as pd
 import requests
 from datasets import Dataset, load_dataset

-from khoj.utils.helpers import is_none_or_empty, timer
+from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer

 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
 SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1  # Sleep between API calls to avoid rate limiting


+class Counter:
+    """Thread-safe counter for tracking metrics"""
+
+    def __init__(self, value=0.0):
+        self.value = value
+        self.lock = Lock()
+
+    def add(self, amount):
+        with self.lock:
+            self.value += amount
+
+    def get(self):
+        with self.lock:
+            return self.value
+
+
+# Track running metrics while evaluating
+running_cost = Counter()
+running_true_count = Counter(0)
+running_false_count = Counter(0)
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
     return None


-def get_agent_response(prompt: str) -> str:
+def get_agent_response(prompt: str) -> Dict[str, Any]:
     """Get response from the Khoj API"""
+    # Set headers
+    headers = {"Content-Type": "application/json"}
+    if not is_none_or_empty(KHOJ_API_KEY):
+        headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
+
     try:
         response = requests.post(
             KHOJ_CHAT_API_URL,
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
+            headers=headers,
             json={
                 "q": prompt,
                 "create_new": True,
             },
         )
         response.raise_for_status()
-        return response.json().get("response", "")
+        response_json = response.json()
+        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return ""
+        return {"response": "", "usage": {}}


-def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
+def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
     Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
             },
         )
         response.raise_for_status()
+        response_json = response.json()
+
+        # Update cost of evaluation
+        input_tokens = response_json["usageMetadata"]["promptTokenCount"]
+        output_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
+        cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, output_tokens)

         # Parse evaluation response
         eval_response: dict[str, str] = json.loads(
-            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
+            clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
         )
         decision = str(eval_response.get("decision", "")).upper() == "TRUE"
         explanation = eval_response.get("explanation", "")
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
         if "503 Service Error" in explanation:
             decision = None
         # Extract decision and explanation from structured response
-        return decision, explanation
+        return decision, explanation, cost
     except Exception as e:
         logger.error(f"Error in evaluation: {e}")
-        return None, f"Evaluation failed: {str(e)}"
+        return None, f"Evaluation failed: {str(e)}", 0.0


 def process_batch(batch, batch_start, results, dataset_length):
+    global running_cost
     for idx, (prompt, answer, reasoning_type) in enumerate(batch):
         current_index = batch_start + idx
         logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,17 @@ def process_batch(batch, batch_start, results, dataset_length):
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt

         # Get agent response
-        agent_response = get_agent_response(prompt)
+        response = get_agent_response(prompt)
+        agent_response = response["response"]
+        agent_usage = response["usage"]

         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This maybe due to a service error."
+            eval_cost = 0.0  # No evaluation call is made for empty responses
         else:
-            decision, explanation = evaluate_response(prompt, agent_response, answer)
+            decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)

         # Store results
         results.append(
@@ -192,17 +232,38 @@ def process_batch(batch, batch_start, results, dataset_length):
                 "evaluation_decision": decision,
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
+                "usage": agent_usage,
             }
         )

-        # Log results
+        # Update running cost
+        query_cost = float(agent_usage.get("cost", 0.0))
+        running_cost.add(query_cost + eval_cost)
+
+        # Update running accuracy
+        running_accuracy = 0.0
+        if decision is not None:
+            running_true_count.add(1) if decision == True else running_false_count.add(1)
+            running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
+
+        ## Log results
         decision_color = {True: "green", None: "blue", False: "red"}[decision]
         colored_decision = color_text(str(decision), decision_color)
-        logger.info(
-            f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
-        )
+        result_to_print = f"""
+---------
+Decision: {colored_decision}
+Accuracy: {running_accuracy:.2%}
+Question: {prompt}
+Expected Answer: {answer}
+Agent Answer: {agent_response}
+Explanation: {explanation}
+Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
+---------
+        """
+        logger.info(dedent(result_to_print).lstrip())

-        time.sleep(SLEEP_SECONDS)  # Rate limiting
+        # Sleep between API calls to avoid rate limiting
+        time.sleep(SLEEP_SECONDS)


 def color_text(text, color):
@@ -281,17 +342,18 @@ def main():
         lambda x: (x == True).mean()
     )

-    # Print summary
+    # Collect summary
     colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
-    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
-    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
-
-    # Save summary to file
+    colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
+    accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
+    accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
+    cost = f"Total Cost: ${running_cost.get():.5f}."
     sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
     sample_type += " Randomized." if RANDOMIZE else ""
-    summary = (
-        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
-    )
+    logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
+
+    # Save summary to file
+    summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
     summary_file = args.output.replace(".csv", ".txt") if args.output else None
     summary_file = (
         summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
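
Note for reviewers: the new Counter exists so that the running cost and accuracy tallies stay correct when process_batch is presumably driven by several worker threads (the Lock and "thread-safe" docstring suggest this; the batching driver itself is not shown in these hunks). Below is a minimal, self-contained sketch of that usage pattern. The Counter class is copied from this patch so the snippet runs on its own; the ThreadPoolExecutor driver, the worker function, and the sample cost values are illustrative stand-ins, not code from eval.py.

from concurrent.futures import ThreadPoolExecutor
from threading import Lock


class Counter:
    """Thread-safe counter for tracking metrics (mirrors the class added in this patch)."""

    def __init__(self, value=0.0):
        self.value = value
        self.lock = Lock()

    def add(self, amount):
        with self.lock:
            self.value += amount

    def get(self):
        with self.lock:
            return self.value


running_cost = Counter()


def worker(costs):
    # Stand-in for process_batch: each thread accumulates its per-query costs.
    for cost in costs:
        running_cost.add(cost)


if __name__ == "__main__":
    # Eight threads each add 1,000 small costs; the lock keeps the total exact,
    # whereas a bare module-level float could lose updates on concurrent "+=".
    batches = [[0.00012] * 1000 for _ in range(8)]
    with ThreadPoolExecutor(max_workers=8) as pool:
        pool.map(worker, batches)
    print(f"Total cost: ${running_cost.get():.5f}")  # -> Total cost: $0.96000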