mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Use logger instead of print to track eval
This commit is contained in:
parent
791eb205f6
commit
1ccbf72752
1 changed files with 26 additions and 13 deletions
|
@ -1,5 +1,6 @@
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
@ -8,6 +9,12 @@ import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from khoj.utils.helpers import timer
|
||||||
|
|
||||||
|
# Configure root logger
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
KHOJ_URL = os.getenv("KHOJ_URL", "http://localhost:42110")
|
||||||
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
KHOJ_CHAT_API_URL = f"{KHOJ_URL}/api/chat"
|
||||||
|
@ -35,7 +42,7 @@ def load_frames_dataset():
|
||||||
return dataset["test"][: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset["test"]
|
return dataset["test"][: int(SAMPLE_SIZE)] if SAMPLE_SIZE else dataset["test"]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading dataset: {e}")
|
logger.error(f"Error loading dataset: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,12 +52,15 @@ def get_agent_response(prompt: str) -> str:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
KHOJ_CHAT_API_URL,
|
KHOJ_CHAT_API_URL,
|
||||||
headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
|
headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
|
||||||
json={"q": prompt},
|
json={
|
||||||
|
"q": prompt,
|
||||||
|
"create_new": True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json().get("response", "")
|
return response.json().get("response", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting agent response: {e}")
|
logger.error(f"Error getting agent response: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,14 +97,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic
|
||||||
"explanation": eval_response.get("explanation", ""),
|
"explanation": eval_response.get("explanation", ""),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in evaluation: {e}")
|
logger.error(f"Error in evaluation: {e}")
|
||||||
return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}
|
return {"decision": "FALSE", "explanation": f"Evaluation failed: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
def process_batch(batch, counter, results, dataset_length):
|
def process_batch(batch, counter, results, dataset_length):
|
||||||
for prompt, answer, reasoning_type in batch:
|
for prompt, answer, reasoning_type in batch:
|
||||||
counter += 1
|
counter += 1
|
||||||
print(f"Processing example: {counter}/{dataset_length}")
|
logger.info(f"Processing example: {counter}/{dataset_length}")
|
||||||
|
|
||||||
# Trigger research mode if enabled
|
# Trigger research mode if enabled
|
||||||
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
|
prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE else prompt
|
||||||
|
@ -119,9 +129,10 @@ def process_batch(batch, counter, results, dataset_length):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Color the decision based on its value
|
# Color the decision based on its value
|
||||||
decision_color = "green" if evaluation["decision"] == True else "red"
|
decision = evaluation["decision"]
|
||||||
colored_decision = color_text(evaluation["decision"], decision_color)
|
decision_color = "green" if decision == True else "red"
|
||||||
print(
|
colored_decision = color_text(str(decision), decision_color)
|
||||||
|
logger.info(
|
||||||
f'Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {evaluation["explanation"]}\n'
|
f'Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {evaluation["explanation"]}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -140,6 +151,7 @@ def clean_json(response: str):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Load dataset
|
# Load dataset
|
||||||
|
with timer("Loaded dataset in", logger):
|
||||||
dataset = load_frames_dataset()
|
dataset = load_frames_dataset()
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return
|
return
|
||||||
|
@ -174,9 +186,9 @@ def main():
|
||||||
df.to_csv("frames_evaluation_results.csv", index=False)
|
df.to_csv("frames_evaluation_results.csv", index=False)
|
||||||
|
|
||||||
# Print summary
|
# Print summary
|
||||||
print(f"\nOverall Accuracy: {accuracy:.2%}")
|
logger.info(f"\nOverall Accuracy: {accuracy:.2%}")
|
||||||
print("\nAccuracy by Reasoning Type:")
|
logger.info("\nAccuracy by Reasoning Type:")
|
||||||
print(reasoning_type_accuracy)
|
logger.info(reasoning_type_accuracy)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -191,4 +203,5 @@ if __name__ == "__main__":
|
||||||
Run the script using the following command:
|
Run the script using the following command:
|
||||||
KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
|
KHOJ_MODE="research" GEMINI_API_KEY="<your_gemini_api_key>" python eval_frames.py
|
||||||
"""
|
"""
|
||||||
|
with timer("Ran eval in", logger):
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue