mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Run prompt batches in parallel for faster eval runs
This commit is contained in:
parent
96904e0769
commit
791eb205f6
1 changed files with 49 additions and 39 deletions
|
@ -1,3 +1,4 @@
|
|||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
@ -90,6 +91,43 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
|
|||
return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}
|
||||
|
||||
|
||||
def process_batch(batch, counter, results, dataset_length):
    """Run each example of one batch through the agent and grade the answer.

    For every (prompt, answer, reasoning_type) triple in ``batch`` this
    queries the agent, evaluates the response against the expected answer,
    appends a result record to ``results``, prints a colored summary and
    then sleeps for ``SLEEP_SECONDS``.

    Args:
        batch: Iterable of (prompt, answer, reasoning_type) triples.
        counter: Value the example numbering continues from; the first
            example of the batch is reported as ``counter + 1``.
        results: List that the per-example result dicts are appended to
            in place.
        dataset_length: Total number of examples, used only in the
            progress line.

    NOTE(review): the numbering is local to this call. When ``counter`` is
    an int the caller's variable is not advanced, so batches submitted with
    the same ``counter`` report overlapping ``index`` values - confirm
    against the caller.
    """
    for position, (prompt, answer, reasoning_type) in enumerate(batch, start=1):
        example_number = counter + position
        print(f"Processing example: {example_number}/{dataset_length}")

        # Prefix the mode command (e.g. research mode) when one is configured
        query = prompt
        if KHOJ_MODE:
            query = f"/{KHOJ_MODE} {prompt}"

        # Ask the agent, then grade its answer against the ground truth
        agent_response = get_agent_response(query)
        evaluation = evaluate_response(query, agent_response, answer)
        decision = evaluation["decision"]
        explanation = evaluation["explanation"]

        # Record the outcome for the metrics computed by the caller
        results.append(
            {
                "index": example_number,
                "prompt": query,
                "ground_truth": answer,
                "agent_response": agent_response,
                "evaluation_decision": decision,
                "evaluation_explanation": explanation,
                "reasoning_type": reasoning_type,
            }
        )

        # Compare against True explicitly rather than by truthiness: a string
        # decision such as "FALSE" is truthy but must still be shown in red.
        if decision == True:  # noqa: E712
            decision_color = "green"
        else:
            decision_color = "red"
        colored_decision = color_text(decision, decision_color)

        summary = "\n".join(
            (
                f"Decision: {colored_decision}",
                f"Question: {query}",
                f"Expected Answer: {answer}",
                f"Agent Answer: {agent_response}",
                f"Explanation: {explanation}",
                "",  # keeps the trailing newline of the summary
            )
        )
        print(summary)

        time.sleep(SLEEP_SECONDS)  # Rate limiting
|
||||
|
||||
|
||||
def color_text(text, color):
    """Return ``text`` wrapped in the ANSI escape code for ``color``.

    The result starts with the escape sequence looked up for ``color`` and
    ends with the reset sequence. ``color`` must be one of the keys of the
    table below ("red", "green" or "reset"); any other value raises
    ``KeyError``.
    """
    ansi_codes = {"red": "\033[91m", "green": "\033[92m", "reset": "\033[0m"}
    start_code = ansi_codes[color]
    reset_code = ansi_codes["reset"]
    return f"{start_code}{text}{reset_code}"
|
||||
|
@ -109,49 +147,21 @@ def main():
|
|||
# Initialize variables
|
||||
counter = 0
|
||||
results = []
|
||||
dataset_length = len(dataset["Prompt"])
|
||||
|
||||
# Process examples in batches
|
||||
for i in range(0, len(dataset), BATCH_SIZE):
|
||||
batch = zip(
|
||||
dataset["Prompt"][i : i + BATCH_SIZE],
|
||||
dataset["Answer"][i : i + BATCH_SIZE],
|
||||
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
||||
)
|
||||
|
||||
for prompt, answer, reasoning_type in batch:
|
||||
counter += 1
|
||||
print(f'Processing example: {counter}/{len(dataset["Prompt"])}')
|
||||
|
||||
# Trigger research mode if enabled
|
||||
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
|
||||
|
||||
# Get agent response
|
||||
agent_response = get_agent_response(prompt)
|
||||
|
||||
# Evaluate response
|
||||
evaluation = evaluate_response(agent_response, answer)
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"index": i,
|
||||
"prompt": prompt,
|
||||
"ground_truth": answer,
|
||||
"agent_response": agent_response,
|
||||
"evaluation_decision": evaluation["decision"],
|
||||
"evaluation_explanation": evaluation["explanation"],
|
||||
"reasoning_type": reasoning_type,
|
||||
}
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for i in range(0, dataset_length, BATCH_SIZE):
|
||||
batch = zip(
|
||||
dataset["Prompt"][i : i + BATCH_SIZE],
|
||||
dataset["Answer"][i : i + BATCH_SIZE],
|
||||
dataset["reasoning_types"][i : i + BATCH_SIZE],
|
||||
)
|
||||
futures.append(executor.submit(process_batch, batch, counter, results, dataset_length))
|
||||
|
||||
# Color the decision based on its value
|
||||
decision_color = "green" if evaluation["decision"] == True else "red"
|
||||
colored_decision = color_text(evaluation["decision"], decision_color)
|
||||
print(
|
||||
f'Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {evaluation["explanation"]}\n'
|
||||
)
|
||||
|
||||
time.sleep(SLEEP_SECONDS) # Rate limiting
|
||||
# Wait for all futures to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
# Calculate metrics
|
||||
df = pd.DataFrame(results)
|
||||
|
|
Loading…
Reference in a new issue