diff --git a/.github/workflows/run_evals.yml b/.github/workflows/run_evals.yml index fdca6f12..46c6d11c 100644 --- a/.github/workflows/run_evals.yml +++ b/.github/workflows/run_evals.yml @@ -1,9 +1,10 @@ -name: Run Khoj Evals +name: eval on: - # Run on every releases - release: - types: [published] + # Run on every release + push: + tags: + - "*" # Allow manual triggers from GitHub UI workflow_dispatch: inputs: @@ -82,7 +83,7 @@ jobs: sed -i 's/dynamic = \["version"\]/version = "${{ steps.hatch.outputs.version }}"/' pyproject.toml pip install --upgrade .[dev] - - name: 📝 Run Evals + - name: 📝 Run Eval env: KHOJ_MODE: ${{ matrix.khoj_mode }} SAMPLE_SIZE: ${{ inputs.sample_size }} diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts index 408ce3a1..552a54bd 100644 --- a/src/interface/obsidian/src/chat_view.ts +++ b/src/interface/obsidian/src/chat_view.ts @@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView { console.log("Started streaming", new Date()); } else if (chunk.type === 'end_llm_response') { console.log("Stopped streaming", new Date()); - + } else if (chunk.type === 'end_response') { // Automatically respond with voice if the subscribed user has sent voice message if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.chatMessageState.rawResponse); diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts index aeb74ee8..6585b4c9 100644 --- a/src/interface/web/app/common/chatFunctions.ts +++ b/src/interface/web/app/common/chatFunctions.ts @@ -133,7 +133,7 @@ export function processMessageChunk( console.log(`Started streaming: ${new Date()}`); } else if (chunk.type === "end_llm_response") { console.log(`Completed streaming: ${new Date()}`); - + } else if (chunk.type === "end_response") { // Append any references after all the data has been streamed if (codeContext) currentMessage.codeContext = codeContext; if (onlineContext) currentMessage.onlineContext = onlineContext; diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index cdce63c6..4f6b68c2 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils import state -from khoj.utils.helpers import in_debug_mode, is_none_or_empty +from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty logger = logging.getLogger(__name__) @@ -59,6 +59,7 @@ def anthropic_completion_with_backoff( aggregated_response = "{" if response_type == "json_object" else "" max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC + final_message = None model_kwargs = model_kwargs or dict() if system_prompt: model_kwargs["system"] = system_prompt @@ -73,6 +74,12 @@ def anthropic_completion_with_backoff( ) as stream: for text in stream.text_stream: aggregated_response += text + final_message = stream.get_final_message() + + # Calculate cost of chat + input_tokens = final_message.usage.input_tokens + output_tokens = final_message.usage.output_tokens + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) # Save conversation trace tracer["chat_model"] = model_name @@ -126,6 +133,7 @@ def anthropic_llm_thread( ] aggregated_response = "" + final_message = None with client.messages.stream( messages=formatted_messages, 
model=model_name, # type: ignore @@ -138,6 +146,12 @@ def anthropic_llm_thread( for text in stream.text_stream: aggregated_response += text g.send(text) + final_message = stream.get_final_message() + + # Calculate cost of chat + input_tokens = final_message.usage.input_tokens + output_tokens = final_message.usage.output_tokens + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) # Save conversation trace tracer["chat_model"] = model_name diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 84ad607e..eb9b21b0 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import ( get_image_from_url, ) from khoj.utils import state -from khoj.utils.helpers import in_debug_mode, is_none_or_empty +from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty logger = logging.getLogger(__name__) @@ -68,6 +68,7 @@ def gemini_completion_with_backoff( response = chat_session.send_message(formatted_messages[-1]["parts"]) response_text = response.text except StopCandidateException as e: + response = None response_text, _ = handle_gemini_response(e.args) # Respond with reason for stopping logger.warning( @@ -75,6 +76,11 @@ def gemini_completion_with_backoff( + f"Last Message by {messages[-1].role}: {messages[-1].content}" ) + # Aggregate cost of chat + input_tokens = response.usage_metadata.prompt_token_count if response else 0 + output_tokens = response.usage_metadata.candidates_token_count if response else 0 + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature @@ -146,6 +152,11 @@ def gemini_llm_thread( if stopped: raise StopCandidateException(message) + # Calculate cost of chat + input_tokens = chunk.usage_metadata.prompt_token_count + output_tokens = chunk.usage_metadata.candidates_token_count + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index a8a94a2e..ddc59d76 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -4,6 +4,8 @@ from threading import Thread from typing import Dict import openai +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from tenacity import ( before_sleep_log, retry, @@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import ( commit_conversation_trace, ) from khoj.utils import state -from khoj.utils.helpers import in_debug_mode +from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode logger = logging.getLogger(__name__) @@ -63,27 +65,34 @@ def completion_with_backoff( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat = client.chat.completions.create( - stream=stream, + chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( messages=formatted_messages, # type: ignore model=model, # type: ignore + stream=stream, + stream_options={"include_usage": True} if stream else {}, 
temperature=temperature, timeout=20, **(model_kwargs or dict()), ) - if not stream: - return chat.choices[0].message.content - aggregated_response = "" - for chunk in chat: - if len(chunk.choices) == 0: - continue - delta_chunk = chunk.choices[0].delta # type: ignore - if isinstance(delta_chunk, str): - aggregated_response += delta_chunk - elif delta_chunk.content: - aggregated_response += delta_chunk.content + if not stream: + chunk = chat + aggregated_response = chunk.choices[0].message.content + else: + for chunk in chat: + if len(chunk.choices) == 0: + continue + delta_chunk = chunk.choices[0].delta # type: ignore + if isinstance(delta_chunk, str): + aggregated_response += delta_chunk + elif delta_chunk.content: + aggregated_response += delta_chunk.content + + # Calculate cost of chat + input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 + output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 + tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage")) # Save conversation trace tracer["chat_model"] = model @@ -162,10 +171,11 @@ def llm_thread( if os.getenv("KHOJ_LLM_SEED"): model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - chat = client.chat.completions.create( - stream=stream, + chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create( messages=formatted_messages, model=model_name, # type: ignore + stream=stream, + stream_options={"include_usage": True} if stream else {}, temperature=temperature, timeout=20, **(model_kwargs or dict()), @@ -173,7 +183,8 @@ def llm_thread( aggregated_response = "" if not stream: - aggregated_response = chat.choices[0].message.content + chunk = chat + aggregated_response = chunk.choices[0].message.content g.send(aggregated_response) else: for chunk in chat: @@ -189,6 +200,11 @@ def llm_thread( aggregated_response += text_chunk g.send(text_chunk) + # Calculate cost of chat + input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 + output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0 + tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage")) + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c02958c9..155a109c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -5,7 +5,6 @@ import math import mimetypes import os import queue -import re import uuid from dataclasses import dataclass from datetime import datetime @@ -57,7 +56,7 @@ model_to_prompt_size = { "gemini-1.5-flash": 20000, "gemini-1.5-pro": 20000, # Anthropic Models - "claude-3-5-sonnet-20240620": 20000, + "claude-3-5-sonnet-20241022": 20000, "claude-3-5-haiku-20241022": 20000, # Offline Models "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000, @@ -213,6 +212,8 @@ class ChatEvent(Enum): REFERENCES = "references" STATUS = "status" METADATA = "metadata" + USAGE = "usage" + END_RESPONSE = "end_response" def message_to_log( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e9f31ea9..8c8bf765 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -667,27 +667,37 @@ async def chat( finally: yield event_delimiter - async def send_llm_response(response: str): + async def 
send_llm_response(response: str, usage: dict = None): + # Send Chat Response async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""): yield result async for result in send_event(ChatEvent.MESSAGE, response): yield result async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result + # Send Usage Metadata once llm interactions are complete + if usage: + async for event in send_event(ChatEvent.USAGE, usage): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result def collect_telemetry(): # Gather chat response telemetry nonlocal chat_metadata latency = time.perf_counter() - start_time cmd_set = set([cmd.value for cmd in conversation_commands]) + cost = (tracer.get("usage", {}) or {}).get("cost", 0) chat_metadata = chat_metadata or {} chat_metadata["conversation_command"] = cmd_set chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None chat_metadata["latency"] = f"{latency:.3f}" chat_metadata["ttft_latency"] = f"{ttft:.3f}" + chat_metadata["usage"] = tracer.get("usage") logger.info(f"Chat response time to first token: {ttft:.3f} seconds") logger.info(f"Chat response total time: {latency:.3f} seconds") + logger.info(f"Chat response cost: ${cost:.5f}") update_telemetry_state( request=request, telemetry_type="api", @@ -699,7 +709,7 @@ async def chat( ) if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): + async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")): yield result return @@ -713,7 +723,7 @@ async def chat( create_new=body.create_new, ) if not conversation: - async for result in send_llm_response(f"Conversation {conversation_id} not found"): + async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")): yield result return conversation_id = conversation.id @@ -777,7 +787,7 @@ async def chat( await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) q = q.replace(f"/{cmd.value}", "").strip() except HTTPException as e: - async for result in send_llm_response(str(e.detail)): + async for result in send_llm_response(str(e.detail), tracer.get("usage")): yield result return @@ -834,7 +844,7 @@ async def chat( agent_has_entries = await EntryAdapters.aagent_has_entries(agent) if len(file_filters) == 0 and not agent_has_entries: response_log = "No files selected for summarization. Please add files using the section on the left." - async for result in send_llm_response(response_log): + async for result in send_llm_response(response_log, tracer.get("usage")): yield result else: async for response in generate_summary_from_files( @@ -853,7 +863,7 @@ async def chat( else: if isinstance(response, str): response_log = response - async for result in send_llm_response(response): + async for result in send_llm_response(response, tracer.get("usage")): yield result await sync_to_async(save_to_conversation_log)( @@ -880,7 +890,7 @@ async def chat( conversation_config = await ConversationAdapters.aget_default_conversation_config(user) model_type = conversation_config.model_type formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) - async for result in send_llm_response(formatted_help): + async for result in send_llm_response(formatted_help, tracer.get("usage")): yield result return # Adding specification to search online specifically on khoj.dev pages. 
@@ -895,7 +905,7 @@ async def chat( except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") error_message = f"Unable to create automation. Ensure the automation doesn't already exist." - async for result in send_llm_response(error_message): + async for result in send_llm_response(error_message, tracer.get("usage")): yield result return @@ -916,7 +926,7 @@ async def chat( raw_query_files=raw_query_files, tracer=tracer, ) - async for result in send_llm_response(llm_response): + async for result in send_llm_response(llm_response, tracer.get("usage")): yield result return @@ -963,7 +973,7 @@ async def chat( yield result if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user): - async for result in send_llm_response(f"{no_entries_found.format()}"): + async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")): yield result return @@ -1105,7 +1115,7 @@ async def chat( "detail": improved_image_prompt, "image": None, } - async for result in send_llm_response(json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): yield result return @@ -1132,7 +1142,7 @@ async def chat( "inferredQueries": [improved_image_prompt], "image": generated_image, } - async for result in send_llm_response(json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): yield result return @@ -1166,7 +1176,7 @@ async def chat( diagram_description = excalidraw_diagram_description else: error_message = "Failed to generate diagram. Please try again later." - async for result in send_llm_response(error_message): + async for result in send_llm_response(error_message, tracer.get("usage")): yield result await sync_to_async(save_to_conversation_log)( @@ -1213,7 +1223,7 @@ async def chat( tracer=tracer, ) - async for result in send_llm_response(json.dumps(content_obj)): + async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")): yield result return @@ -1252,6 +1262,11 @@ async def chat( if item is None: async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""): yield result + # Send Usage Metadata once llm interactions are complete + async for event in send_event(ChatEvent.USAGE, tracer.get("usage")): + yield event + async for result in send_event(ChatEvent.END_RESPONSE, ""): + yield result logger.debug("Finished streaming response") return if not connection_alive or not continue_stream: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index b011106e..4fd30de9 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1770,6 +1770,7 @@ Manage your automations [here](/automations). 
class MessageProcessor: def __init__(self): self.references = {} + self.usage = {} self.raw_response = "" def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]: @@ -1793,6 +1794,8 @@ class MessageProcessor: chunk_type = ChatEvent(chunk["type"]) if chunk_type == ChatEvent.REFERENCES: self.references = chunk["data"] + elif chunk_type == ChatEvent.USAGE: + self.usage = chunk["data"] elif chunk_type == ChatEvent.MESSAGE: chunk_data = chunk["data"] if isinstance(chunk_data, dict): @@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict if buffer: processor.process_message_chunk(buffer) - return {"response": processor.raw_response, "references": processor.references} + return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage} def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False): diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 591a0f08..40320373 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Dict app_root_directory = Path(__file__).parent.parent.parent web_directory = app_root_directory / "khoj/interface/web/" @@ -31,3 +32,19 @@ default_config = { "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"}, }, } + +model_to_cost: Dict[str, Dict[str, float]] = { + # OpenAI Pricing: https://openai.com/api/pricing/ + "gpt-4o": {"input": 2.50, "output": 10.00}, + "gpt-4o-mini": {"input": 0.15, "output": 0.60}, + "o1-preview": {"input": 15.0, "output": 60.00}, + "o1-mini": {"input": 3.0, "output": 12.0}, + # Gemini Pricing: https://ai.google.dev/pricing + "gemini-1.5-flash": {"input": 0.075, "output": 0.30}, + "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30}, + "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, + "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00}, + # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_ + "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0}, + "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0}, +} diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d1617d79..02cd7a92 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str: def get_country_name_from_timezone(tz: str) -> str: """Get country name from timezone""" return country_names.get(get_country_code_from_timezone(tz), "United States") + + +def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0): + """ + Calculate cost of chat message based on input and output tokens + """ + + # Calculate cost of input and output tokens. 
Costs are per million tokens + input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6) + output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6) + + return input_cost + output_cost + prev_cost + + +def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}): + """ + Get usage metrics for chat message based on input and output tokens + """ + prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0} + return { + "input_tokens": prev_usage["input_tokens"] + input_tokens, + "output_tokens": prev_usage["output_tokens"] + output_tokens, + "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]), + } diff --git a/tests/evals/eval.py b/tests/evals/eval.py index 9c018b76..30c842b2 100644 --- a/tests/evals/eval.py +++ b/tests/evals/eval.py @@ -6,13 +6,15 @@ import os import time from datetime import datetime from io import StringIO +from textwrap import dedent +from threading import Lock from typing import Any, Dict import pandas as pd import requests from datasets import Dataset, load_dataset -from khoj.utils.helpers import is_none_or_empty, timer +from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer # Configure root logger logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -38,6 +40,28 @@ BATCH_SIZE = int( SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1 # Sleep between API calls to avoid rate limiting +class Counter: + """Thread-safe counter for tracking metrics""" + + def __init__(self, value=0.0): + self.value = value + self.lock = Lock() + + def add(self, amount): + with self.lock: + self.value += amount + + def get(self): + with self.lock: + return self.value + + +# Track running metrics while evaluating +running_cost = Counter() +running_true_count = Counter(0) +running_false_count = Counter(0) + + def load_frames_dataset(): """ Load the Google FRAMES benchmark dataset from HuggingFace @@ -104,25 +128,31 @@ def load_simpleqa_dataset(): return None -def get_agent_response(prompt: str) -> str: +def get_agent_response(prompt: str) -> Dict[str, Any]: """Get response from the Khoj API""" + # Set headers + headers = {"Content-Type": "application/json"} + if not is_none_or_empty(KHOJ_API_KEY): + headers["Authorization"] = f"Bearer {KHOJ_API_KEY}" + try: response = requests.post( KHOJ_CHAT_API_URL, - headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"}, + headers=headers, json={ "q": prompt, "create_new": True, }, ) response.raise_for_status() - return response.json().get("response", "") + response_json = response.json() + return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})} except Exception as e: logger.error(f"Error getting agent response: {e}") - return "" + return {"response": "", "usage": {}} -def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]: +def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]: """Evaluate Khoj response against benchmark ground truth using Gemini""" evaluation_prompt = f""" Compare the following agent response with the ground truth answer. 
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic }, ) response.raise_for_status() + response_json = response.json() + + # Update cost of evaluation + input_tokens = response_json["usageMetadata"]["promptTokenCount"] + ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"] + cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens) # Parse evaluation response eval_response: dict[str, str] = json.loads( - clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"]) + clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"]) ) decision = str(eval_response.get("decision", "")).upper() == "TRUE" explanation = eval_response.get("explanation", "") @@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dic if "503 Service Error" in explanation: decision = None # Extract decision and explanation from structured response - return decision, explanation + return decision, explanation, cost except Exception as e: logger.error(f"Error in evaluation: {e}") - return None, f"Evaluation failed: {str(e)}" + return None, f"Evaluation failed: {str(e)}", 0.0 def process_batch(batch, batch_start, results, dataset_length): + global running_cost for idx, (prompt, answer, reasoning_type) in enumerate(batch): current_index = batch_start + idx logger.info(f"Processing example: {current_index}/{dataset_length}") @@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length): prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt # Get agent response - agent_response = get_agent_response(prompt) + response = get_agent_response(prompt) + agent_response = response["response"] + agent_usage = response["usage"] # Evaluate response if is_none_or_empty(agent_response): decision = None explanation = "Agent response is empty. This maybe due to a service error." 
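+            eval_cost = 0.0  # default evaluation cost when evaluation is skipped, so the running-cost update below never references an unset variable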
else: - decision, explanation = evaluate_response(prompt, agent_response, answer) + decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer) # Store results results.append( @@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length): "evaluation_decision": decision, "evaluation_explanation": explanation, "reasoning_type": reasoning_type, + "usage": agent_usage, } ) - # Log results + # Update running cost + query_cost = float(agent_usage.get("cost", 0.0)) + running_cost.add(query_cost + eval_cost) + + # Update running accuracy + running_accuracy = 0.0 + if decision is not None: + running_true_count.add(1) if decision == True else running_false_count.add(1) + running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get()) + + ## Log results decision_color = {True: "green", None: "blue", False: "red"}[decision] colored_decision = color_text(str(decision), decision_color) - logger.info( - f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n" - ) + result_to_print = f""" +--------- +Decision: {colored_decision} +Accuracy: {running_accuracy:.2%} +Question: {prompt} +Expected Answer: {answer} +Agent Answer: {agent_response} +Explanation: {explanation} +Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f}) +--------- + """ + logger.info(dedent(result_to_print).lstrip()) - time.sleep(SLEEP_SECONDS) # Rate limiting + # Sleep between API calls to avoid rate limiting + time.sleep(SLEEP_SECONDS) def color_text(text, color): @@ -281,17 +341,18 @@ def main(): lambda x: (x == True).mean() ) - # Print summary + # Collect summary colored_accuracy = color_text(f"{accuracy:.2%}", "blue") - logger.info(f"\nOverall Accuracy: {colored_accuracy}") - logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}") - - # Save summary to file + colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset." + accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}." + accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}" + cost = f"Total Cost: ${running_cost.get():.5f}." sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset." sample_type += " Randomized." if RANDOMIZE else "" - summary = ( - f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n" - ) + logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n") + + # Save summary to file + summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n" summary_file = args.output.replace(".csv", ".txt") if args.output else None summary_file = ( summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"