diff --git a/tests/eval_frames.py b/tests/eval_frames.py index 7e469722..9998d4c9 100644 --- a/tests/eval_frames.py +++ b/tests/eval_frames.py @@ -9,7 +9,7 @@ from typing import Any, Dict import pandas as pd import requests -from datasets import load_dataset +from datasets import Dataset, load_dataset from khoj.utils.helpers import is_none_or_empty, timer @@ -48,6 +48,38 @@ def load_frames_dataset(): return None +def load_talc_dataset(): + """ + Load the TALC dataset from Github. + + Normalize it into the FRAMES benchmark structure and the HuggingFace Dataset format. + """ + try: + # Load TALC search benchmark from Github + raw_url = "https://raw.githubusercontent.com/Talc-AI/search-bench/3fd5b0858e2effa4c1578c7d046bee0a3895c488/data/searchbench_08_30_2024.jsonl" + response = requests.get(raw_url) + response.raise_for_status() + + # Parse benchmark from raw JSONL response + jsonl_data = [json.loads(line) for line in response.text.splitlines()] + + # Rename keys to match FRAMES format + formatted_data = [ + {"Prompt": d["question"], "Answer": d["expected_answer"], "reasoning_types": "talc"} for d in jsonl_data + ] + + # Convert benchmark to HF Dataset + dataset = Dataset.from_list(formatted_data) + dataset = dataset.shuffle() if RANDOMIZE else dataset + dataset = dataset.select(range(int(SAMPLE_SIZE))) if SAMPLE_SIZE else dataset + + return dataset + + except Exception as e: + logger.error(f"Error loading dataset: {e}") + return None + + def get_agent_response(prompt: str) -> str: """Get response from the Khoj API""" try: @@ -153,7 +185,7 @@ def color_text(text, color): colors = { "red": "\033[91m", # Bright red "green": "\033[32m", # Standard green - "blue": "\033[94m", # Bright blue + "blue": "\033[34m", # Standard blue "reset": "\033[0m", } return f"{colors[color]}{text}{colors['reset']}" @@ -165,12 +197,19 @@ def clean_json(response: str): def parse_args(): - parser = argparse.ArgumentParser(description="Evaluate Khoj on the Google FRAMES benchmark.") + parser = 
argparse.ArgumentParser(description="Evaluate Khoj on a supported benchmark.") parser.add_argument( "--output", "-o", default=None, - help="Path to store evaluation results CSV (default: frames_evaluation_results_[datetime].csv)", + help="Path to store evaluation results CSV (default: [benchmark]_evaluation_results_[datetime].csv)", + ) + parser.add_argument( + "--dataset", + "-d", + default="frames", + choices=["frames", "talc"], + help="Dataset to use for evaluation (default: frames)", ) return parser.parse_args() @@ -178,10 +217,14 @@ def parse_args(): def main(): # Initialize variables args = parse_args() + dataset = None # Load dataset - with timer("Loaded dataset in", logger): - dataset = load_frames_dataset() + with timer(f"Loaded {args.dataset} dataset in", logger): + if args.dataset == "frames": + dataset = load_frames_dataset() + elif args.dataset == "talc": + dataset = load_talc_dataset() if dataset is None: return @@ -215,18 +258,19 @@ def main(): ) # Print summary - logger.info(f"\nOverall Accuracy: {accuracy:.2%}") + colored_accuracy = color_text(f"{accuracy:.2%}", "blue") + logger.info(f"\nOverall Accuracy: {colored_accuracy}") logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}") # Save results - output_file = args.output or f"frames_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv" + output_file = args.output or f"{args.dataset}_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv" df.to_csv(output_file, index=False) logger.info(f"Results saved to {output_file}") if __name__ == "__main__": """ - Evaluate Khoj on the Google FRAMES benchmark. + Evaluate Khoj on supported benchmarks. Response are evaluated by GEMINI_EVAL_MODEL (default: gemini-pro-1.5-002). Khoj should be running at KHOJ_URL (default: http://localhost:42110). 
@@ -236,7 +280,7 @@ if __name__ == "__main__": Run the script using the following command: KHOJ_MODE="research" GEMINI_API_KEY="" python eval_frames.py """ - logger.info(f"{datetime.now()} - Begin Quizzing Khoj on the FRAMES benchmark.") + logger.info(f"{datetime.now()} - Begin Quizzing Khoj.") with timer("Ran eval script in", logger, log_level=logging.INFO): main() - logger.info(f"{datetime.now()} - End Quizzing Khoj on the FRAMES benchmark.") + logger.info(f"{datetime.now()} - End Quizzing Khoj.")