2024-11-13 12:13:36 +01:00
|
|
|
import argparse
|
2024-11-02 12:58:03 +01:00
|
|
|
import concurrent.futures
|
2024-11-02 10:38:26 +01:00
|
|
|
import json
|
2024-11-03 02:20:42 +01:00
|
|
|
import logging
|
2024-11-02 10:38:26 +01:00
|
|
|
import os
|
|
|
|
import time
|
2024-11-13 12:13:36 +01:00
|
|
|
from datetime import datetime
|
2024-11-15 00:55:00 +01:00
|
|
|
from io import StringIO
|
2024-11-02 10:38:26 +01:00
|
|
|
from typing import Any, Dict
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
import requests
|
2024-11-14 05:02:56 +01:00
|
|
|
from datasets import Dataset, load_dataset
|
2024-11-02 10:38:26 +01:00
|
|
|
|
2024-11-13 12:13:36 +01:00
|
|
|
from khoj.utils.helpers import is_none_or_empty, timer
|
2024-11-03 02:20:42 +01:00
|
|
|
|
|
|
|
# Configure root logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# Configuration
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
KHOJ_API_KEY = os.getenv("KHOJ_API_KEY")
KHOJ_MODE = os.getenv("KHOJ_MODE", "default")  # E.g research, general, notes etc.

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_EVAL_MODEL = os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-pro-002")
GEMINI_API_URL = (
    f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_EVAL_MODEL}:generateContent?key={GEMINI_API_KEY}"
)

SAMPLE_SIZE = os.getenv("SAMPLE_SIZE")  # Number of examples to evaluate
RANDOMIZE = os.getenv("RANDOMIZE", "false").lower() == "true"  # Randomize examples
# Examples to evaluate in each batch. Defaults to a tenth of SAMPLE_SIZE, or 10 when the whole
# dataset is evaluated. Clamped to at least 1: main() uses it as a range() step, and a zero step
# raises ValueError (previously hit whenever SAMPLE_SIZE < 10 or BATCH_SIZE=0).
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", int(SAMPLE_SIZE) // 10 if SAMPLE_SIZE else 10)))
SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1  # Sleep between API calls to avoid rate limiting
|
2024-11-02 10:38:26 +01:00
|
|
|
|
|
|
|
|
|
|
|
def load_frames_dataset():
    """
    Load the Google FRAMES benchmark dataset from HuggingFace

    FRAMES is a benchmark dataset to evaluate retrieval and answering capabilities of agents.
    It contains ~800 questions requiring multi-hop retrieval and reasoning across various topics.

    ### Data Fields
    - Prompt: The question to be answered
    - Answer: The ground truth answer
    - reasoning_types: The type of reasoning required to answer the question

    Returns the "test" split, shuffled first when RANDOMIZE is set and cut to the first
    SAMPLE_SIZE rows when SAMPLE_SIZE is set. Returns None (after logging) if loading fails.
    """
    try:
        splits = load_dataset("google/frames-benchmark")
        if RANDOMIZE:
            splits = splits.shuffle()

        test_split = splits["test"]
        if not SAMPLE_SIZE:
            return test_split

        # Slicing the split yields the column-to-values mapping consumed by main()
        return test_split[: int(SAMPLE_SIZE)]
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        return None
|
|
|
|
|
|
|
|
|
2024-11-15 00:55:00 +01:00
|
|
|
def load_simpleqa_dataset():
    """
    Load the OpenAI SimpleQA benchmark dataset from their public bucket.

    SimpleQA is a dataset of moderately difficult q&a for 2024 models to answer across various topics.
    It contains ~4000 human vetted questions and answers with additional metadata.
    Its usage can be seen in openai/simple-evals github repository as well.

    ### Data Fields
    - problem: The question to be answered
    - answer: The ground truth answer
    - metadata: Additional metadata including topic information

    Rows are normalized into the FRAMES field names (Prompt, Answer, reasoning_types), using each
    row's own metadata topic as its reasoning type. Returns an HF Dataset, shuffled when RANDOMIZE
    is set and limited to at most SAMPLE_SIZE rows when SAMPLE_SIZE is set, or None on failure.
    """
    import ast

    try:
        # Load SimpleQA benchmark from OpenAI public bucket
        raw_url = "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv"
        response = requests.get(raw_url)
        response.raise_for_status()

        # Parse benchmark from raw CSV response
        csv_data = pd.read_csv(StringIO(response.text))

        # Normalize it into FRAMES format. Convert the frame to records once (not once per row) and
        # read each row's own metadata cell so every example gets its own topic.
        # The metadata cells look like Python dict reprs (single-quoted strings), so parse them with
        # ast.literal_eval, which accepts only literals; swapping ' for " before json.loads breaks
        # whenever a value itself contains a quote character.
        formatted_data = [
            {
                "Prompt": record["problem"],
                "Answer": record["answer"],
                "reasoning_types": ast.literal_eval(record["metadata"])["topic"],
            }
            for record in csv_data.to_dict("records")
        ]

        # Convert benchmark to HF Dataset
        dataset = Dataset.from_list(formatted_data)
        dataset = dataset.shuffle() if RANDOMIZE else dataset
        if SAMPLE_SIZE:
            # Clamp so a SAMPLE_SIZE larger than the benchmark behaves like the FRAMES slice
            dataset = dataset.select(range(min(int(SAMPLE_SIZE), len(dataset))))

        return dataset
    except Exception as e:
        logger.error(f"Error loading simpleqa dataset: {e}")
        return None
|
|
|
|
|
|
|
|
|
2024-11-02 10:38:26 +01:00
|
|
|
def get_agent_response(prompt: str) -> str:
    """Get response from the Khoj API.

    POSTs the prompt to KHOJ_CHAT_API_URL (with create_new set) using KHOJ_API_KEY as a bearer
    token, and returns the "response" field of the JSON reply. On any failure the error is logged
    and an empty string is returned.
    """
    request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"}
    payload = {"q": prompt, "create_new": True}
    try:
        reply = requests.post(KHOJ_CHAT_API_URL, headers=request_headers, json=payload)
        reply.raise_for_status()
        return reply.json().get("response", "")
    except Exception as e:
        logger.error(f"Error getting agent response: {e}")
        return ""
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_response(query: str, agent_response: str, ground_truth: str) -> "tuple[bool | None, str]":
    """Evaluate Khoj response against benchmark ground truth using Gemini.

    Args:
        query: The question that was posed to the agent.
        agent_response: The agent's answer to judge.
        ground_truth: The benchmark's expected answer.

    Returns:
        A ``(decision, explanation)`` tuple. ``decision`` is True/False as judged by
        GEMINI_EVAL_MODEL, or None when the evaluation could not be completed (request or parse
        failure, or the judge's explanation reports a 503 service error).
    """
    # Literal braces in the JSON template are doubled: inside an f-string a single {...} is an
    # interpolated expression, which rendered the template as a Python tuple repr, not JSON.
    evaluation_prompt = f"""
    Compare the following agent response with the ground truth answer.
    Determine if the agent response contains the key information from the ground truth.
    Focus on factual correctness rather than exact wording.

    Query: {query}
    Agent Response: {agent_response}
    Ground Truth: {ground_truth}

    Provide your evaluation in the following json format:
    {{"explanation": "(How you made the decision?)", "decision": "(TRUE if response contains key information, FALSE otherwise)"}}
    """

    try:
        response = requests.post(
            GEMINI_API_URL,
            headers={"Content-Type": "application/json"},
            json={
                "contents": [{"parts": [{"text": evaluation_prompt}]}],
                "generationConfig": {"response_mime_type": "application/json"},
            },
        )
        response.raise_for_status()

        # Extract decision and explanation from the judge's structured JSON response
        eval_response: dict[str, str] = json.loads(
            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
        )
        decision = str(eval_response.get("decision", "")).strip().upper() == "TRUE"
        explanation = eval_response.get("explanation", "")

        # Handle evaluation service errors
        if "503 Service Error" in explanation:
            decision = None

        return decision, explanation
    except Exception as e:
        logger.error(f"Error in evaluation: {e}")
        return None, f"Evaluation failed: {str(e)}"
|
2024-11-02 10:38:26 +01:00
|
|
|
|
|
|
|
|
2024-11-09 00:46:44 +01:00
|
|
|
def process_batch(batch, batch_start, results, dataset_length):
    """Query Khoj for each example in a batch and judge its answer.

    Args:
        batch: Iterable of (prompt, answer, reasoning_type) triples.
        batch_start: Dataset index of the first triple; used to number examples.
        results: Shared list that receives one result dict per example.
        dataset_length: Total number of examples, used only for progress logging.

    Sleeps SLEEP_SECONDS after every example to avoid rate limiting.
    """
    mode_prefix = f"/{KHOJ_MODE}"
    for current_index, (prompt, answer, reasoning_type) in enumerate(batch, start=batch_start):
        logger.info(f"Processing example: {current_index}/{dataset_length}")

        # Trigger the configured Khoj mode (e.g. research) unless the prompt already starts with it
        if KHOJ_MODE and not prompt.startswith(mode_prefix):
            prompt = f"{mode_prefix} {prompt}"

        agent_response = get_agent_response(prompt)

        # An empty answer is not sent to the judge; it is recorded as an unevaluated example
        if is_none_or_empty(agent_response):
            decision, explanation = None, "Agent response is empty. This maybe due to a service error."
        else:
            decision, explanation = evaluate_response(prompt, agent_response, answer)

        results.append(
            {
                "index": current_index,
                "prompt": prompt,
                "ground_truth": answer,
                "agent_response": agent_response,
                "evaluation_decision": decision,
                "evaluation_explanation": explanation,
                "reasoning_type": reasoning_type,
            }
        )

        decision_color = {True: "green", None: "blue", False: "red"}[decision]
        logger.info(
            f"Decision: {color_text(str(decision), decision_color)}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
        )

        time.sleep(SLEEP_SECONDS)  # Rate limiting
|
|
|
|
|
|
|
|
|
2024-11-02 10:38:26 +01:00
|
|
|
def color_text(text, color):
    """Wrap text in the ANSI escape code for color, followed by an ANSI reset code.

    Supported colors are "red", "green" and "blue" (plus "reset"); any other name raises KeyError.
    """
    ansi_codes = {
        "red": "\033[91m",  # Bright red
        "green": "\033[32m",  # Standard green
        "blue": "\033[34m",  # Standard blue
        "reset": "\033[0m",
    }
    start = ansi_codes[color]
    end = ansi_codes["reset"]
    return f"{start}{text}{end}"
|
|
|
|
|
|
|
|
|
|
|
|
def clean_json(response: str):
    """Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models.

    Strips a leading ```json or bare ``` fence and a trailing ``` fence, removes newline characters,
    and trims surrounding whitespace so the result can be passed straight to json.loads.
    """
    cleaned = response.strip().replace("\n", "")
    # Try the tagged fence first so a "json" tag is never left behind, then a bare fence
    cleaned = cleaned.removeprefix("```json").removeprefix("```")
    return cleaned.removesuffix("```").strip()
|
|
|
|
|
|
|
|
|
2024-11-13 12:13:36 +01:00
|
|
|
def parse_args():
    """Parse the command line options for the evaluation script.

    Options:
        --output / -o: Path of the results CSV (default None, meaning a timestamped name).
        --dataset / -d: Benchmark to run, "frames" (default) or "simpleqa".
    """
    cli = argparse.ArgumentParser(description="Evaluate Khoj on a supported benchmark.")
    option_specs = (
        (
            ("--output", "-o"),
            {
                "default": None,
                "help": "Path to store evaluation results CSV (default: [benchmark]_evaluation_results_[datetime].csv)",
            },
        ),
        (
            ("--dataset", "-d"),
            {
                "default": "frames",
                "choices": ["frames", "simpleqa"],
                "help": "Dataset to use for evaluation (default: frames)",
            },
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
|
|
|
|
|
|
|
|
|
2024-11-02 10:38:26 +01:00
|
|
|
def main():
    """Run the selected benchmark against Khoj and save the results.

    Loads the dataset chosen via --dataset and evaluates it in batches of BATCH_SIZE on a thread
    pool. It then logs overall and per-reasoning-type accuracy (rows without a decision are
    excluded), writes a text summary, and writes the raw per-example results to a CSV. File paths
    come from --output, or default to timestamped names.
    """
    # Initialize variables
    args = parse_args()
    dataset = None

    # Load dataset
    with timer(f"Loaded {args.dataset} dataset in", logger, log_level=logging.INFO):
        if args.dataset == "frames":
            dataset = load_frames_dataset()
        elif args.dataset == "simpleqa":
            dataset = load_simpleqa_dataset()
    if dataset is None:
        return

    # Initialize variables
    results = []
    dataset_length = len(dataset["Prompt"])

    # Process examples in batches
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for i in range(0, dataset_length, BATCH_SIZE):
            batch = zip(
                dataset["Prompt"][i : i + BATCH_SIZE],
                dataset["Answer"][i : i + BATCH_SIZE],
                dataset["reasoning_types"][i : i + BATCH_SIZE],
            )
            futures.append(executor.submit(process_batch, batch, i, results, dataset_length))

        # Wait for all batches. Report any batch that crashed instead of silently losing its examples
        for future in concurrent.futures.as_completed(futures):
            error = future.exception()
            if error is not None:
                logger.error(f"Error processing batch: {error}")

    if not results:
        logger.error("No examples were evaluated. Skipping summary.")
        return

    # Calculate metrics. Sort by dataset index since batches complete in arbitrary order
    df = pd.DataFrame(results).sort_values("index", ignore_index=True)
    eval_df = df.dropna(subset=["evaluation_decision"])  # Exclude rows with missing evaluation decision
    accuracy = (eval_df["evaluation_decision"] == True).mean()  # noqa: E712

    # Calculate accuracy by reasoning type
    reasoning_type_accuracy = eval_df.groupby("reasoning_type")["evaluation_decision"].apply(
        lambda x: (x == True).mean()  # noqa: E712
    )

    # Print summary
    colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")

    # Save summary to file
    sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
    sample_type += " Randomized." if RANDOMIZE else ""
    summary = (
        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
    )

    # One timestamp for both default file names so they always match
    run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = args.output or f"{args.dataset}_evaluation_results_{run_timestamp}.csv"
    if args.output:
        # Swap only the extension. If that still equals the output path (e.g. --output results.txt),
        # pick a distinct name so the CSV written below does not overwrite the summary.
        output_root, _ = os.path.splitext(args.output)
        summary_file = f"{output_root}.txt"
        if summary_file == args.output:
            summary_file = f"{output_root}_summary.txt"
    else:
        summary_file = f"{args.dataset}_evaluation_summary_{run_timestamp}.txt"

    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(summary)

    # Save raw results to file
    df.to_csv(output_file, index=False)
    logger.info(f"Results saved to {summary_file}, {output_file}")
|
2024-11-02 10:38:26 +01:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    """
    Evaluate Khoj on supported benchmarks.
    Responses are evaluated by GEMINI_EVAL_MODEL (default: gemini-1.5-pro-002).

    Khoj should be running at KHOJ_URL (default: http://localhost:42110).
    The Gemini judge model is accessed via the Gemini API with your GEMINI_API_KEY.
    To evaluate Khoj in research mode, set the KHOJ_MODE environment variable to "research".
    Choose the benchmark with --dataset (frames or simpleqa) and the results CSV path with --output.

    Run the script using the following command:
    KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
    """
    logger.info(f"{datetime.now()} - Begin Quizzing Khoj.")
    with timer("Ran eval script in", logger, log_level=logging.INFO):
        main()
    logger.info(f"{datetime.now()} - End Quizzing Khoj.")
|