mirror of https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Enable evaluating Khoj on the Talc Search Bench using the eval script

- Load the raw JSONL from GitHub and normalize it into the FRAMES format (sketched below)
- Color the accuracy printed by the eval script blue for readability
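For orientation, here is a minimal sketch of the normalization the new loader performs. This is an editor's illustration, not code from the commit: the field names `question` and `expected_answer` and the FRAMES column names come from the diff below, while the sample record is invented.

```python
import json

# One raw TALC search-bench record (illustrative values; only the keys
# "question" and "expected_answer" are taken from the eval script below).
raw_line = '{"question": "Who wrote Dune?", "expected_answer": "Frank Herbert"}'
record = json.loads(raw_line)

# Rename keys to the FRAMES benchmark's column names so a single eval
# path can serve both datasets; all TALC rows share one reasoning type.
frames_record = {
    "Prompt": record["question"],
    "Answer": record["expected_answer"],
    "reasoning_types": "talc",
}
print(frames_record)
```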
This commit is contained in:
parent 8e009f48ce
commit 9fc44f1a7f

1 changed file with 55 additions and 11 deletions
```diff
@@ -9,7 +9,7 @@ from typing import Any, Dict
 import pandas as pd
 import requests
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 
 from khoj.utils.helpers import is_none_or_empty, timer
 
 
```
```diff
@@ -48,6 +48,38 @@ def load_frames_dataset():
         return None
 
 
+def load_talc_dataset():
+    """
+    Load the TALC dataset from Github.
+
+    Normalize it into the FRAMES benchmark structure and the HuggingFace Dataset format.
+    """
+    try:
+        # Load TALC search benchmark from Github
+        raw_url = "https://raw.githubusercontent.com/Talc-AI/search-bench/3fd5b0858e2effa4c1578c7d046bee0a3895c488/data/searchbench_08_30_2024.jsonl"
+        response = requests.get(raw_url)
+        response.raise_for_status()
+
+        # Parse benchmark from raw JSONL response
+        jsonl_data = [json.loads(line) for line in response.text.splitlines()]
+
+        # Rename keys to match FRAMES format
+        formatted_data = [
+            {"Prompt": d["question"], "Answer": d["expected_answer"], "reasoning_types": "talc"} for d in jsonl_data
+        ]
+
+        # Convert benchmark to HF Dataset
+        dataset = Dataset.from_list(formatted_data)
+        dataset = dataset.shuffle() if RANDOMIZE else dataset
+        dataset = dataset.select(range(int(SAMPLE_SIZE))) if SAMPLE_SIZE else dataset
+
+        return dataset
+
+    except Exception as e:
+        logger.error(f"Error loading dataset: {e}")
+        return None
+
+
 def get_agent_response(prompt: str) -> str:
     """Get response from the Khoj API"""
     try:
```
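The HuggingFace `datasets` plumbing at the end of `load_talc_dataset` can be exercised in isolation. A minimal sketch, with inline literals standing in for the script's `RANDOMIZE` and `SAMPLE_SIZE` globals (defined elsewhere in the file, not shown in this diff):

```python
from datasets import Dataset

# Two stand-in rows in the normalized FRAMES shape.
rows = [
    {"Prompt": "Q1", "Answer": "A1", "reasoning_types": "talc"},
    {"Prompt": "Q2", "Answer": "A2", "reasoning_types": "talc"},
]

dataset = Dataset.from_list(rows)   # list of dicts -> in-memory HF Dataset
dataset = dataset.shuffle()         # applied only when RANDOMIZE is truthy
dataset = dataset.select(range(1))  # applied only when SAMPLE_SIZE is set
print(dataset[0])
```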
```diff
@@ -153,7 +185,7 @@ def color_text(text, color):
     colors = {
         "red": "\033[91m",  # Bright red
         "green": "\033[32m",  # Standard green
-        "blue": "\033[94m",  # Bright blue
+        "blue": "\033[34m",  # Standard blue
         "reset": "\033[0m",
     }
     return f"{colors[color]}{text}{colors['reset']}"
```
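The helper being adjusted here is self-contained, so the change is easy to check standalone. A sketch mirroring the function above, assuming a terminal that honors ANSI escape codes (on terminals that don't, the escapes print literally):

```python
def color_text(text, color):
    # ANSI SGR escape codes; \033[34m is standard blue, \033[0m resets.
    colors = {
        "red": "\033[91m",  # Bright red
        "green": "\033[32m",  # Standard green
        "blue": "\033[34m",  # Standard blue
        "reset": "\033[0m",
    }
    return f"{colors[color]}{text}{colors['reset']}"

print(color_text("87.50%", "blue"))  # illustrative accuracy value
```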
```diff
@@ -165,12 +197,19 @@ def clean_json(response: str):
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="Evaluate Khoj on the Google FRAMES benchmark.")
+    parser = argparse.ArgumentParser(description="Evaluate Khoj on a supported benchmark.")
     parser.add_argument(
         "--output",
         "-o",
         default=None,
-        help="Path to store evaluation results CSV (default: frames_evaluation_results_[datetime].csv)",
+        help="Path to store evaluation results CSV (default: [benchmark]_evaluation_results_[datetime].csv)",
     )
+    parser.add_argument(
+        "--dataset",
+        "-d",
+        default="frames",
+        choices=["frames", "talc"],
+        help="Dataset to use for evaluation (default: frames)",
+    )
     return parser.parse_args()
 
```
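The new `--dataset` flag can be checked without running the full eval. A minimal sketch mirroring just the argparse surface added above:

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Evaluate Khoj on a supported benchmark.")
    parser.add_argument("--output", "-o", default=None,
                        help="Path to store evaluation results CSV")
    parser.add_argument("--dataset", "-d", default="frames", choices=["frames", "talc"],
                        help="Dataset to use for evaluation (default: frames)")
    return parser.parse_args(argv)

print(parse_args(["--dataset", "talc"]))
# -> Namespace(dataset='talc', output=None)
```

Combined with the invocation pattern in the script's own docstring (see the last hunk), this becomes e.g. `KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py -d talc`.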
```diff
@@ -178,10 +217,14 @@ def parse_args():
 def main():
     # Initialize variables
     args = parse_args()
     dataset = None
 
     # Load dataset
-    with timer("Loaded dataset in", logger):
-        dataset = load_frames_dataset()
+    with timer(f"Loaded {args.dataset} dataset in", logger):
+        if args.dataset == "frames":
+            dataset = load_frames_dataset()
+        elif args.dataset == "talc":
+            dataset = load_talc_dataset()
+
     if dataset is None:
         return
```
```diff
@@ -215,18 +258,19 @@ def main():
     )
 
     # Print summary
-    logger.info(f"\nOverall Accuracy: {accuracy:.2%}")
+    colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
+    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
     logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
 
     # Save results
-    output_file = args.output or f"frames_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
+    output_file = args.output or f"{args.dataset}_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
     df.to_csv(output_file, index=False)
     logger.info(f"Results saved to {output_file}")
 
 
 if __name__ == "__main__":
     """
-    Evaluate Khoj on the Google FRAMES benchmark.
+    Evaluate Khoj on supported benchmarks.
     Responses are evaluated by GEMINI_EVAL_MODEL (default: gemini-pro-1.5-002).
 
     Khoj should be running at KHOJ_URL (default: http://localhost:42110).
```
```diff
@@ -236,7 +280,7 @@ if __name__ == "__main__":
     Run the script using the following command:
     KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
     """
-    logger.info(f"{datetime.now()} - Begin Quizzing Khoj on the FRAMES benchmark.")
+    logger.info(f"{datetime.now()} - Begin Quizzing Khoj.")
    with timer("Ran eval script in", logger, log_level=logging.INFO):
         main()
-    logger.info(f"{datetime.now()} - End Quizzing Khoj on the FRAMES benchmark.")
+    logger.info(f"{datetime.now()} - End Quizzing Khoj.")
```