Use logger instead of print to track eval

Debanjum 2024-11-02 18:20:42 -07:00
parent 791eb205f6
commit 1ccbf72752


@@ -1,5 +1,6 @@
 import concurrent.futures
 import json
+import logging
 import os
 import time
 from typing import Any, Dict
@@ -8,6 +9,12 @@ import pandas as pd
 import requests
 from datasets import load_dataset
+from khoj.utils.helpers import timer
+
+# Configure root logger
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Configuration
 KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
 KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
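With the root logger configured this way, every logger obtained via logging.getLogger(__name__) inherits the timestamped format and the INFO threshold. A minimal sketch of the resulting behavior (the messages are illustrative, not from a real eval run):

import logging

# Same setup as the diff above: INFO threshold, timestamped message format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)

logger.info("Processing example: 1/100")  # emits e.g. "2024-11-02 18:20:42,000 - Processing example: 1/100"
logger.debug("not shown")                 # suppressed: DEBUG is below the INFO threshold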
@@ -35,7 +42,7 @@ def load_frames_dataset():
         return dataset["test"][: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset["test"]
     except Exception as e:
-        print(f"Error loading dataset: {e}")
+        logger.error(f"Error loading dataset: {e}")
         return None
@@ -45,12 +52,15 @@ def get_agent_response(prompt: str) -> str:
         response = requests.post(
             KHOJ_CHAT_API_URL,
             headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
-            json={"q": prompt},
+            json={
+                "q": prompt,
+                "create_new": True,
+            },
         )
         response.raise_for_status()
         return response.json().get("response", "")
     except Exception as e:
-        print(f"Error getting agent response: {e}")
+        logger.error(f"Error getting agent response: {e}")
         return ""
@@ -87,14 +97,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict:
             "explanation": eval_response.get("explanation", ""),
         }
     except Exception as e:
-        print(f"Error in evaluation: {e}")
+        logger.error(f"Error in evaluation: {e}")
         return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}


 def process_batch(batch, counter, results, dataset_length):
     for prompt, answer, reasoning_type in batch:
         counter += 1
-        print(f"Processing example: {counter}/{dataset_length}")
+        logger.info(f"Processing example: {counter}/{dataset_length}")

         # Trigger research mode if enabled
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
@@ -119,9 +129,10 @@ def process_batch(batch, counter, results, dataset_length):
         )

         # Color the decision based on its value
-        decision_color = "green" if evaluation["decision"] == True else "red"
-        colored_decision = color_text(evaluation["decision"], decision_color)
-        print(
+        decision = evaluation["decision"]
+        decision_color = "green" if decision == True else "red"
+        colored_decision = color_text(str(decision), decision_color)
+        logger.info(
             f'Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {evaluation["explanation"]}\n'
         )
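color_text is defined elsewhere in this script and not shown in the diff. A plausible sketch with the same signature, using raw ANSI escape codes (a hypothetical reimplementation; the real helper may differ):

# Hypothetical stand-in for the script's color_text helper.
ANSI_CODES = {"green": "\033[92m", "red": "\033[91m"}
ANSI_RESET = "\033[0m"

def color_text(text: str, color: str) -> str:
    """Wrap text in an ANSI color escape, falling back to plain text for unknown colors."""
    code = ANSI_CODES.get(color)
    return f"{code}{text}{ANSI_RESET}" if code else text

The new str(decision) call matters if the judge returns a boolean rather than a string, since a helper like this expects text.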
@@ -140,6 +151,7 @@ def clean_json(response: str):

 def main():
     # Load dataset
-    dataset = load_frames_dataset()
+    with timer("Loaded dataset in", logger):
+        dataset = load_frames_dataset()
     if dataset is None:
         return
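timer is imported from khoj.utils.helpers at the top of the diff. Its implementation isn't shown here, but given the call shape (a message prefix plus a logger to emit to) it is presumably a context manager along these lines (an assumed sketch, not the actual Khoj code):

import logging
import time
from contextlib import contextmanager

@contextmanager
def timer(message: str, logger: logging.Logger):
    """Log how long the wrapped block took, e.g. 'Loaded dataset in 1.234 seconds'."""
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"{message} {time.perf_counter() - start:.3f} seconds")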
@@ -174,9 +186,9 @@ def main():
     df.to_csv("frames_evaluation_results.csv", index=False)

     # Print summary
-    print(f"\nOverall Accuracy: {accuracy:.2%}")
-    print("\nAccuracy by Reasoning Type:")
-    print(reasoning_type_accuracy)
+    logger.info(f"\nOverall Accuracy: {accuracy:.2%}")
+    logger.info("\nAccuracy by Reasoning Type:")
+    logger.info(reasoning_type_accuracy)


 if __name__ == "__main__":
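One subtlety in the summary block: logger.info(reasoning_type_accuracy) passes a pandas object rather than a string. This works because the logging module calls str() on the message lazily, so the table renders the same as it did under print:

import logging

import pandas as pd

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)

# Illustrative accuracy-by-reasoning-type table, not real eval results.
reasoning_type_accuracy = pd.Series({"numerical": 0.42, "tabular": 0.55}, name="accuracy")
logger.info(reasoning_type_accuracy)  # logging stringifies the Series into its table form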
@@ -191,4 +203,5 @@ if __name__ == "__main__":
     Run the script using the following command:
     KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
     """
-    main()
+    with timer("Ran eval in", logger):
+        main()