From a62ff614fb8f681b5d33767e5e010767f7e5b91c Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Fri, 8 Nov 2024 15:46:44 -0800
Subject: [PATCH] Show correct example index being currently processed in
 frames eval

Previously the batch start index wasn't being passed so all batches
started in parallel were showing the same processing example index

This change doesn't impact the evaluation itself, just the index shown
of the example currently being evaluated
---
 tests/eval_frames.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/eval_frames.py b/tests/eval_frames.py
index ae44f63e..102da469 100644
--- a/tests/eval_frames.py
+++ b/tests/eval_frames.py
@@ -101,10 +101,10 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
         return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}
 
 
-def process_batch(batch, counter, results, dataset_length):
-    for prompt, answer, reasoning_type in batch:
-        counter += 1
-        logger.info(f"Processing example: {counter}/{dataset_length}")
+def process_batch(batch, batch_start, results, dataset_length):
+    for idx, (prompt, answer, reasoning_type) in enumerate(batch):
+        current_index = batch_start + idx
+        logger.info(f"Processing example: {current_index}/{dataset_length}")
 
         # Trigger research mode if enabled
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
@@ -122,7 +122,7 @@ def process_batch(batch, counter, results, dataset_length):
         # Store results
         results.append(
             {
-                "index": counter,
+                "index": current_index,
                 "prompt": prompt,
                 "ground_truth": answer,
                 "agent_response": agent_response,
@@ -169,12 +169,13 @@ def main():
     with concurrent.futures.ThreadPoolExecutor() as executor:
         futures = []
         for i in range(0, dataset_length, BATCH_SIZE):
+            batch_start = i
             batch = zip(
                 dataset["Prompt"][i : i + BATCH_SIZE],
                 dataset["Answer"][i : i + BATCH_SIZE],
                 dataset["reasoning_types"][i : i + BATCH_SIZE],
             )
-            futures.append(executor.submit(process_batch, batch, counter, results, dataset_length))
+            futures.append(executor.submit(process_batch, batch, batch_start, results, dataset_length))
 
         # Wait for all futures to complete
         concurrent.futures.wait(futures)