mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 10:53:02 +01:00
Enable evaluating Khoj on the Talc Search Bench using the eval script
- Load the raw JSONL from GitHub and normalize it into the FRAMES format
- Color the printed accuracy in the eval script blue for readability
This commit is contained in:
parent
8e009f48ce
commit
9fc44f1a7f
1 changed file with 55 additions and 11 deletions
|
@ -9,7 +9,7 @@ from typing import Any, Dict
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
from datasets import load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
|
|
||||||
from khoj.utils.helpers import is_none_or_empty, timer
|
from khoj.utils.helpers import is_none_or_empty, timer
|
||||||
|
|
||||||
|
@ -48,6 +48,38 @@ def load_frames_dataset():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_talc_dataset():
    """
    Load the TALC dataset from Github.

    Normalize it into the FRAMES benchmark structure and the HuggingFace Dataset format.

    Returns:
        A Dataset with "Prompt", "Answer" and "reasoning_types" columns, shuffled when
        RANDOMIZE is set and truncated to at most SAMPLE_SIZE rows when SAMPLE_SIZE is set.
        Returns None (after logging the error) if the benchmark cannot be fetched or parsed.
    """
    try:
        # Load TALC search benchmark from Github. The URL is pinned to a specific commit.
        raw_url = "https://raw.githubusercontent.com/Talc-AI/search-bench/3fd5b0858e2effa4c1578c7d046bee0a3895c488/data/searchbench_08_30_2024.jsonl"
        # Bound the request so an unresponsive host cannot hang the eval run indefinitely
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        # Parse benchmark from raw JSONL response.
        # Split on "\n" only: str.splitlines() also breaks on characters like U+2028,
        # which may appear unescaped inside a JSON string and would cut a record in half.
        # Skip blank lines (e.g. from a trailing newline) instead of failing the whole load.
        jsonl_data = [json.loads(line) for line in response.text.split("\n") if line.strip()]

        # Rename keys to match FRAMES format
        formatted_data = [
            {"Prompt": d["question"], "Answer": d["expected_answer"], "reasoning_types": "talc"} for d in jsonl_data
        ]

        # Convert benchmark to HF Dataset
        dataset = Dataset.from_list(formatted_data)
        if RANDOMIZE:
            dataset = dataset.shuffle()
        if SAMPLE_SIZE:
            # Clamp to the dataset length so an oversized SAMPLE_SIZE evaluates the full
            # benchmark rather than failing on out-of-range indices.
            sample_count = min(int(SAMPLE_SIZE), len(dataset))
            dataset = dataset.select(range(sample_count))

        return dataset

    except Exception as e:
        # Callers treat None as "dataset unavailable", so log and return rather than raise
        logger.error(f"Error loading dataset: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
def get_agent_response(prompt: str) -> str:
|
def get_agent_response(prompt: str) -> str:
|
||||||
"""Get response from the Khoj API"""
|
"""Get response from the Khoj API"""
|
||||||
try:
|
try:
|
||||||
|
@ -153,7 +185,7 @@ def color_text(text, color):
|
||||||
colors = {
|
colors = {
|
||||||
"red": "\033[91m", # Bright red
|
"red": "\033[91m", # Bright red
|
||||||
"green": "\033[32m", # Standard green
|
"green": "\033[32m", # Standard green
|
||||||
"blue": "\033[94m", # Bright blue
|
"blue": "\033[34m", # Bright blue
|
||||||
"reset": "\033[0m",
|
"reset": "\033[0m",
|
||||||
}
|
}
|
||||||
return f"{colors[color]}{text}{colors['reset']}"
|
return f"{colors[color]}{text}{colors['reset']}"
|
||||||
|
@ -165,12 +197,19 @@ def clean_json(response: str):
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Evaluate Khoj on the Google FRAMES benchmark.")
|
parser = argparse.ArgumentParser(description="Evaluate Khoj on a supported benchmark.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output",
|
||||||
"-o",
|
"-o",
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to store evaluation results CSV (default: frames_evaluation_results_[datetime].csv)",
|
help="Path to store evaluation results CSV (default: [benchmark]_evaluation_results_[datetime].csv)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
"-d",
|
||||||
|
default="frames",
|
||||||
|
choices=["frames", "talc"],
|
||||||
|
help="Dataset to use for evaluation (default: frames)",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -178,10 +217,14 @@ def parse_args():
|
||||||
def main():
|
def main():
|
||||||
# Initialize variables
|
# Initialize variables
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
dataset = None
|
||||||
|
|
||||||
# Load dataset
|
# Load dataset
|
||||||
with timer("Loaded dataset in", logger):
|
with timer(f"Loaded {args.dataset} dataset in", logger):
|
||||||
|
if args.dataset == "frames":
|
||||||
dataset = load_frames_dataset()
|
dataset = load_frames_dataset()
|
||||||
|
elif args.dataset == "talc":
|
||||||
|
dataset = load_talc_dataset()
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -215,18 +258,19 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
logger.info(f"\nOverall Accuracy: {accuracy:.2%}")
|
colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
|
||||||
|
logger.info(f"\nOverall Accuracy: {colored_accuracy}")
|
||||||
logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
|
logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
output_file = args.output or f"frames_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
|
output_file = args.output or f"{args.dataset}_evaluation_results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
|
||||||
df.to_csv(output_file, index=False)
|
df.to_csv(output_file, index=False)
|
||||||
logger.info(f"Results saved to {output_file}")
|
logger.info(f"Results saved to {output_file}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
Evaluate Khoj on the Google FRAMES benchmark.
|
Evaluate Khoj on supported benchmarks.
|
||||||
Response are evaluated by GEMINI_EVAL_MODEL (default: gemini-pro-1.5-002).
|
Response are evaluated by GEMINI_EVAL_MODEL (default: gemini-pro-1.5-002).
|
||||||
|
|
||||||
Khoj should be running at KHOJ_URL (default: http://localhost:42110).
|
Khoj should be running at KHOJ_URL (default: http://localhost:42110).
|
||||||
|
@ -236,7 +280,7 @@ if __name__ == "__main__":
|
||||||
Run the script using the following command:
|
Run the script using the following command:
|
||||||
KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
|
KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
|
||||||
"""
|
"""
|
||||||
logger.info(f"{datetime.now()} - Begin Quizzing Khoj on the FRAMES benchmark.")
|
logger.info(f"{datetime.now()} - Begin Quizzing Khoj.")
|
||||||
with timer("Ran eval script in", logger, log_level=logging.INFO):
|
with timer("Ran eval script in", logger, log_level=logging.INFO):
|
||||||
main()
|
main()
|
||||||
logger.info(f"{datetime.now()} - End Quizzing Khoj on the FRAMES benchmark.")
|
logger.info(f"{datetime.now()} - End Quizzing Khoj.")
|
||||||
|
|
Loading…
Reference in a new issue