Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 15:38:55 +01:00)
Track Usage Metrics in Chat API. Track Running Cost, Accuracy in Evals (#985)
- Track and return cost and usage metrics in the chat API response: record input and output token usage and the cost of each interaction with the OpenAI, Anthropic, and Google chat models for every call to the Khoj chat API.
- Collect, display, and store the cost and accuracy of the eval run currently in progress. This gives more insight into an eval run while it executes, instead of having to wait until the run completes.
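Note: with this change, a call to the chat API also returns the aggregated token usage and cost of the request alongside the response text. A minimal sketch of reading it, assuming a locally running Khoj server on the default port and an optional KHOJ_API_KEY environment variable (the endpoint address is an assumption, not part of this diff):

    import os
    import requests

    KHOJ_CHAT_API_URL = "http://localhost:42110/api/chat"  # assumed local server address
    headers = {"Content-Type": "application/json"}
    if os.getenv("KHOJ_API_KEY"):
        headers["Authorization"] = f"Bearer {os.getenv('KHOJ_API_KEY')}"

    r = requests.post(KHOJ_CHAT_API_URL, headers=headers, json={"q": "What is the capital of France?", "create_new": True})
    r.raise_for_status()
    body = r.json()
    print(body.get("response", ""))
    # New in this change: accumulated token usage and cost for the request
    print(body.get("usage", {}))  # e.g. {"input_tokens": ..., "output_tokens": ..., "cost": ...}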
Commit 6f1adcfe67
12 changed files with 230 additions and 67 deletions
.github/workflows/run_evals.yml (vendored): 11 changed lines
@@ -1,9 +1,10 @@
-name: Run Khoj Evals
+name: eval

 on:
-  # Run on every releases
-  release:
-    types: [published]
+  # Run on every release
+  push:
+    tags:
+      - "*"
   # Allow manual triggers from GitHub UI
   workflow_dispatch:
     inputs:
@@ -82,7 +83,7 @@ jobs:
           sed -i 's/dynamic = \["version"\]/version = "${{ steps.hatch.outputs.version }}"/' pyproject.toml
           pip install --upgrade .[dev]

-      - name: 📝 Run Evals
+      - name: 📝 Run Eval
        env:
          KHOJ_MODE: ${{ matrix.khoj_mode }}
          SAMPLE_SIZE: ${{ inputs.sample_size }}
@@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
             console.log("Started streaming", new Date());
         } else if (chunk.type === 'end_llm_response') {
             console.log("Stopped streaming", new Date());
+        } else if (chunk.type === 'end_response') {
             // Automatically respond with voice if the subscribed user has sent voice message
             if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
                 this.textToSpeech(this.chatMessageState.rawResponse);
@@ -133,7 +133,7 @@ export function processMessageChunk(
        console.log(`Started streaming: ${new Date()}`);
    } else if (chunk.type === "end_llm_response") {
        console.log(`Completed streaming: ${new Date()}`);
+   } else if (chunk.type === "end_response") {
        // Append any references after all the data has been streamed
        if (codeContext) currentMessage.codeContext = codeContext;
        if (onlineContext) currentMessage.onlineContext = onlineContext;
@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
     aggregated_response = "{" if response_type == "json_object" else ""
     max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

+    final_message = None
     model_kwargs = model_kwargs or dict()
     if system_prompt:
         model_kwargs["system"] = system_prompt
@@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
     ) as stream:
         for text in stream.text_stream:
             aggregated_response += text
+        final_message = stream.get_final_message()

+    # Calculate cost of chat
+    input_tokens = final_message.usage.input_tokens
+    output_tokens = final_message.usage.output_tokens
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model_name
@@ -126,6 +133,7 @@ def anthropic_llm_thread(
     ]

     aggregated_response = ""
+    final_message = None
     with client.messages.stream(
         messages=formatted_messages,
         model=model_name,  # type: ignore
@@ -138,6 +146,12 @@ def anthropic_llm_thread(
         for text in stream.text_stream:
             aggregated_response += text
             g.send(text)
+        final_message = stream.get_final_message()

+    # Calculate cost of chat
+    input_tokens = final_message.usage.input_tokens
+    output_tokens = final_message.usage.output_tokens
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model_name
@@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
         response = chat_session.send_message(formatted_messages[-1]["parts"])
         response_text = response.text
     except StopCandidateException as e:
+        response = None
         response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
@@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )

+    # Aggregate cost of chat
+    input_tokens = response.usage_metadata.prompt_token_count if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
@@ -146,6 +152,11 @@ def gemini_llm_thread(
             if stopped:
                 raise StopCandidateException(message)

+        # Calculate cost of chat
+        input_tokens = chunk.usage_metadata.prompt_token_count
+        output_tokens = chunk.usage_metadata.candidates_token_count
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
@@ -4,6 +4,8 @@ from threading import Thread
 from typing import Dict

 import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from tenacity import (
     before_sleep_log,
     retry,
@@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
     commit_conversation_trace,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode

 logger = logging.getLogger(__name__)

@@ -63,19 +65,21 @@ def completion_with_backoff(
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat = client.chat.completions.create(
-        stream=stream,
+    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
+        stream=stream,
+        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )

-    if not stream:
-        return chat.choices[0].message.content
-
     aggregated_response = ""
+    if not stream:
+        chunk = chat
+        aggregated_response = chunk.choices[0].message.content
+    else:
         for chunk in chat:
             if len(chunk.choices) == 0:
                 continue
@@ -85,6 +89,11 @@ def completion_with_backoff(
             elif delta_chunk.content:
                 aggregated_response += delta_chunk.content

+    # Calculate cost of chat
+    input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model
     tracer["temperature"] = temperature
@@ -162,10 +171,11 @@ def llm_thread(
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat = client.chat.completions.create(
-        stream=stream,
+    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,
         model=model_name,  # type: ignore
+        stream=stream,
+        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
@@ -173,7 +183,8 @@ def llm_thread(

     aggregated_response = ""
     if not stream:
-        aggregated_response = chat.choices[0].message.content
+        chunk = chat
+        aggregated_response = chunk.choices[0].message.content
         g.send(aggregated_response)
     else:
         for chunk in chat:
@@ -189,6 +200,11 @@ def llm_thread(
                 aggregated_response += text_chunk
                 g.send(text_chunk)

+    # Calculate cost of chat
+    input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
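Note: setting stream_options={"include_usage": True} makes the OpenAI streaming API emit one final chunk whose choices list is empty and whose usage field carries the total prompt and completion token counts; that is why the code above keeps the last chunk and reads chunk.usage after the loop. A minimal standalone sketch of that behavior (model name and prompt are placeholders):

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    chunk = None
    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")

    # The final chunk has no choices but carries the usage totals.
    if chunk is not None and chunk.usage:
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)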
@@ -5,7 +5,6 @@ import math
 import mimetypes
 import os
 import queue
-import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
@@ -57,7 +56,7 @@ model_to_prompt_size = {
     "gemini-1.5-flash": 20000,
     "gemini-1.5-pro": 20000,
     # Anthropic Models
-    "claude-3-5-sonnet-20240620": 20000,
+    "claude-3-5-sonnet-20241022": 20000,
     "claude-3-5-haiku-20241022": 20000,
     # Offline Models
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -213,6 +212,8 @@ class ChatEvent(Enum):
     REFERENCES = "references"
     STATUS = "status"
     METADATA = "metadata"
+    USAGE = "usage"
+    END_RESPONSE = "end_response"


 def message_to_log(
@@ -667,27 +667,37 @@ async def chat(
         finally:
             yield event_delimiter

-    async def send_llm_response(response: str):
+    async def send_llm_response(response: str, usage: dict = None):
+        # Send Chat Response
         async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
             yield result
         async for result in send_event(ChatEvent.MESSAGE, response):
             yield result
         async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
             yield result
+        # Send Usage Metadata once llm interactions are complete
+        if usage:
+            async for event in send_event(ChatEvent.USAGE, usage):
+                yield event
+        async for result in send_event(ChatEvent.END_RESPONSE, ""):
+            yield result

     def collect_telemetry():
         # Gather chat response telemetry
         nonlocal chat_metadata
         latency = time.perf_counter() - start_time
         cmd_set = set([cmd.value for cmd in conversation_commands])
+        cost = (tracer.get("usage", {}) or {}).get("cost", 0)
         chat_metadata = chat_metadata or {}
         chat_metadata["conversation_command"] = cmd_set
         chat_metadata["agent"] = conversation.agent.slug if conversation and conversation.agent else None
         chat_metadata["latency"] = f"{latency:.3f}"
         chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+        chat_metadata["usage"] = tracer.get("usage")

         logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
         logger.info(f"Chat response total time: {latency:.3f} seconds")
+        logger.info(f"Chat response cost: ${cost:.5f}")
         update_telemetry_state(
             request=request,
             telemetry_type="api",
@@ -699,7 +709,7 @@ async def chat(
     )

     if is_query_empty(q):
-        async for result in send_llm_response("Please ask your query to get started."):
+        async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
             yield result
         return

@@ -713,7 +723,7 @@ async def chat(
             create_new=body.create_new,
         )
         if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found"):
+            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
                 yield result
             return
         conversation_id = conversation.id
@@ -777,7 +787,7 @@ async def chat(
                 await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
                 q = q.replace(f"/{cmd.value}", "").strip()
             except HTTPException as e:
-                async for result in send_llm_response(str(e.detail)):
+                async for result in send_llm_response(str(e.detail), tracer.get("usage")):
                     yield result
                 return

@@ -834,7 +844,7 @@ async def chat(
             agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
             if len(file_filters) == 0 and not agent_has_entries:
                 response_log = "No files selected for summarization. Please add files using the section on the left."
-                async for result in send_llm_response(response_log):
+                async for result in send_llm_response(response_log, tracer.get("usage")):
                     yield result
             else:
                 async for response in generate_summary_from_files(
@@ -853,7 +863,7 @@ async def chat(
                     else:
                         if isinstance(response, str):
                             response_log = response
-                            async for result in send_llm_response(response):
+                            async for result in send_llm_response(response, tracer.get("usage")):
                                 yield result

             await sync_to_async(save_to_conversation_log)(
@@ -880,7 +890,7 @@ async def chat(
             conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
             model_type = conversation_config.model_type
             formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
-            async for result in send_llm_response(formatted_help):
+            async for result in send_llm_response(formatted_help, tracer.get("usage")):
                 yield result
             return
         # Adding specification to search online specifically on khoj.dev pages.
@@ -895,7 +905,7 @@ async def chat(
         except Exception as e:
             logger.error(f"Error scheduling task {q} for {user.email}: {e}")
             error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
-            async for result in send_llm_response(error_message):
+            async for result in send_llm_response(error_message, tracer.get("usage")):
                 yield result
             return

@@ -916,7 +926,7 @@ async def chat(
             raw_query_files=raw_query_files,
             tracer=tracer,
         )
-        async for result in send_llm_response(llm_response):
+        async for result in send_llm_response(llm_response, tracer.get("usage")):
            yield result
        return

@@ -963,7 +973,7 @@ async def chat(
             yield result

     if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-        async for result in send_llm_response(f"{no_entries_found.format()}"):
+        async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
             yield result
         return

@@ -1105,7 +1115,7 @@ async def chat(
                 "detail": improved_image_prompt,
                 "image": None,
             }
-            async for result in send_llm_response(json.dumps(content_obj)):
+            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                 yield result
             return

@@ -1132,7 +1142,7 @@ async def chat(
             "inferredQueries": [improved_image_prompt],
             "image": generated_image,
         }
-        async for result in send_llm_response(json.dumps(content_obj)):
+        async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
             yield result
         return

@@ -1166,7 +1176,7 @@ async def chat(
                 diagram_description = excalidraw_diagram_description
             else:
                 error_message = "Failed to generate diagram. Please try again later."
-                async for result in send_llm_response(error_message):
+                async for result in send_llm_response(error_message, tracer.get("usage")):
                     yield result

                 await sync_to_async(save_to_conversation_log)(
@@ -1213,7 +1223,7 @@ async def chat(
             tracer=tracer,
         )

-        async for result in send_llm_response(json.dumps(content_obj)):
+        async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
             yield result
         return

@@ -1252,6 +1262,11 @@ async def chat(
             if item is None:
                 async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                     yield result
+                # Send Usage Metadata once llm interactions are complete
+                async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                    yield event
+                async for result in send_event(ChatEvent.END_RESPONSE, ""):
+                    yield result
                 logger.debug("Finished streaming response")
                 return
             if not connection_alive or not continue_stream:
@@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
 class MessageProcessor:
     def __init__(self):
         self.references = {}
+        self.usage = {}
         self.raw_response = ""

     def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
@@ -1793,6 +1794,8 @@ class MessageProcessor:
             chunk_type = ChatEvent(chunk["type"])
             if chunk_type == ChatEvent.REFERENCES:
                 self.references = chunk["data"]
+            elif chunk_type == ChatEvent.USAGE:
+                self.usage = chunk["data"]
             elif chunk_type == ChatEvent.MESSAGE:
                 chunk_data = chunk["data"]
                 if isinstance(chunk_data, dict):
@@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
     if buffer:
         processor.process_message_chunk(buffer)

-    return {"response": processor.raw_response, "references": processor.references}
+    return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}


 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
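Note: the streamed chat response is a sequence of JSON chunks, each carrying a "type" and "data" field; the new "usage" chunk arrives after end_llm_response and before the new end_response event. A simplified sketch of folding such chunks into a final result, in the spirit of MessageProcessor above (delimiter handling and dict-shaped message chunks are omitted here):

    from typing import Any, Dict, List

    def reduce_chat_chunks(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Fold streamed chat chunks into response text, references and usage."""
        raw_response, references, usage = "", {}, {}
        for chunk in chunks:
            if chunk["type"] == "references":
                references = chunk["data"]
            elif chunk["type"] == "usage":
                usage = chunk["data"]
            elif chunk["type"] == "message":
                raw_response += chunk["data"]
        return {"response": raw_response, "references": references, "usage": usage}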
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Dict

 app_root_directory = Path(__file__).parent.parent.parent
 web_directory = app_root_directory / "khoj/interface/web/"
@@ -31,3 +32,19 @@ default_config = {
         "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
     },
 }
+
+model_to_cost: Dict[str, Dict[str, float]] = {
+    # OpenAI Pricing: https://openai.com/api/pricing/
+    "gpt-4o": {"input": 2.50, "output": 10.00},
+    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
+    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1-mini": {"input": 3.0, "output": 12.0},
+    # Gemini Pricing: https://ai.google.dev/pricing
+    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
+    "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
+    # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
+    "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
+    "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
+}
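Note: the prices above are in USD per million tokens, so per-request cost is linear in token counts. For example, a gpt-4o-mini call with 1,000 input tokens and 500 output tokens costs roughly:

    input_price, output_price = 0.15, 0.60  # USD per million tokens for gpt-4o-mini
    cost = input_price * (1_000 / 1e6) + output_price * (500 / 1e6)
    # cost == 0.00015 + 0.00030 == 0.00045 (about $0.0005)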
@@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
 def get_country_name_from_timezone(tz: str) -> str:
     """Get country name from timezone"""
     return country_names.get(get_country_code_from_timezone(tz), "United States")
+
+
+def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
+    """
+    Calculate cost of chat message based on input and output tokens
+    """
+
+    # Calculate cost of input and output tokens. Costs are per million tokens
+    input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
+    output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
+
+    return input_cost + output_cost + prev_cost
+
+
+def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
+    """
+    Get usage metrics for chat message based on input and output tokens
+    """
+    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
+    return {
+        "input_tokens": prev_usage["input_tokens"] + input_tokens,
+        "output_tokens": prev_usage["output_tokens"] + output_tokens,
+        "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
+    }
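Note: because the previous usage dict is passed back in, repeated calls to get_chat_usage_metrics accumulate token counts and cost across all model calls made while serving a single chat request; this is how tracer["usage"] is built up in the provider-specific modules above. A small worked example, using the gpt-4o-mini prices from constants.model_to_cost:

    usage = None
    usage = get_chat_usage_metrics("gpt-4o-mini", input_tokens=1200, output_tokens=300, usage=usage)
    usage = get_chat_usage_metrics("gpt-4o-mini", input_tokens=800, output_tokens=150, usage=usage)
    # usage == {"input_tokens": 2000, "output_tokens": 450, "cost": 0.00057}
    # cost = 0.15 * (2000 / 1e6) + 0.60 * (450 / 1e6) = 0.00030 + 0.00027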
@@ -6,13 +6,15 @@ import os
 import time
 from datetime import datetime
 from io import StringIO
+from textwrap import dedent
+from threading import Lock
 from typing import Any, Dict

 import pandas as pd
 import requests
 from datasets import Dataset, load_dataset

-from khoj.utils.helpers import is_none_or_empty, timer
+from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer

 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -38,6 +40,28 @@ BATCH_SIZE = int(
 SLEEP_SECONDS = 3 if KHOJ_MODE == "general" else 1  # Sleep between API calls to avoid rate limiting


+class Counter:
+    """Thread-safe counter for tracking metrics"""
+
+    def __init__(self, value=0.0):
+        self.value = value
+        self.lock = Lock()
+
+    def add(self, amount):
+        with self.lock:
+            self.value += amount
+
+    def get(self):
+        with self.lock:
+            return self.value
+
+
+# Track running metrics while evaluating
+running_cost = Counter()
+running_true_count = Counter(0)
+running_false_count = Counter(0)
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -104,25 +128,31 @@ def load_simpleqa_dataset():
         return None


-def get_agent_response(prompt: str) -> str:
+def get_agent_response(prompt: str) -> Dict[str, Any]:
     """Get response from the Khoj API"""
+    # Set headers
+    headers = {"Content-Type": "application/json"}
+    if not is_none_or_empty(KHOJ_API_KEY):
+        headers["Authorization"] = f"Bearer {KHOJ_API_KEY}"
+
     try:
         response = requests.post(
             KHOJ_CHAT_API_URL,
-            headers={"Content-Type": "application/json", "Authorization": f"Bearer {KHOJ_API_KEY}"},
+            headers=headers,
             json={
                 "q": prompt,
                 "create_new": True,
             },
         )
         response.raise_for_status()
-        return response.json().get("response", "")
+        response_json = response.json()
+        return {"response": response_json.get("response", ""), "usage": response_json.get("usage", {})}
     except Exception as e:
         logger.error(f"Error getting agent response: {e}")
-        return ""
+        return {"response": "", "usage": {}}


-def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
+def evaluate_response(query: str, agent_response: str, ground_truth: str) -> tuple[bool | None, str, float]:
     """Evaluate Khoj response against benchmark ground truth using Gemini"""
     evaluation_prompt = f"""
     Compare the following agent response with the ground truth answer.
@@ -147,10 +177,16 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
             },
         )
         response.raise_for_status()
+        response_json = response.json()
+
+        # Update cost of evaluation
+        input_tokens = response_json["usageMetadata"]["promptTokenCount"]
+        ouput_tokens = response_json["usageMetadata"]["candidatesTokenCount"]
+        cost = get_cost_of_chat_message(GEMINI_EVAL_MODEL, input_tokens, ouput_tokens)

         # Parse evaluation response
         eval_response: dict[str, str] = json.loads(
-            clean_json(response.json()["candidates"][0]["content"]["parts"][0]["text"])
+            clean_json(response_json["candidates"][0]["content"]["parts"][0]["text"])
         )
         decision = str(eval_response.get("decision", "")).upper() == "TRUE"
         explanation = eval_response.get("explanation", "")
@@ -158,13 +194,14 @@ def evaluate_response(query: str, agent_response: str, ground_truth: str) -> Dict[str, Any]:
         if "503 Service Error" in explanation:
             decision = None
         # Extract decision and explanation from structured response
-        return decision, explanation
+        return decision, explanation, cost
     except Exception as e:
         logger.error(f"Error in evaluation: {e}")
-        return None, f"Evaluation failed: {str(e)}"
+        return None, f"Evaluation failed: {str(e)}", 0.0


 def process_batch(batch, batch_start, results, dataset_length):
+    global running_cost
     for idx, (prompt, answer, reasoning_type) in enumerate(batch):
         current_index = batch_start + idx
         logger.info(f"Processing example: {current_index}/{dataset_length}")
@@ -173,14 +210,16 @@ def process_batch(batch, batch_start, results, dataset_length):
         prompt = f"/{KHOJ_MODE} {prompt}" if KHOJ_MODE and not prompt.startswith(f"/{KHOJ_MODE}") else prompt

         # Get agent response
-        agent_response = get_agent_response(prompt)
+        response = get_agent_response(prompt)
+        agent_response = response["response"]
+        agent_usage = response["usage"]

         # Evaluate response
         if is_none_or_empty(agent_response):
             decision = None
             explanation = "Agent response is empty. This maybe due to a service error."
         else:
-            decision, explanation = evaluate_response(prompt, agent_response, answer)
+            decision, explanation, eval_cost = evaluate_response(prompt, agent_response, answer)

         # Store results
         results.append(
@@ -192,17 +231,38 @@ def process_batch(batch, batch_start, results, dataset_length):
                 "evaluation_decision": decision,
                 "evaluation_explanation": explanation,
                 "reasoning_type": reasoning_type,
+                "usage": agent_usage,
             }
         )

-        # Log results
+        # Update running cost
+        query_cost = float(agent_usage.get("cost", 0.0))
+        running_cost.add(query_cost + eval_cost)
+
+        # Update running accuracy
+        running_accuracy = 0.0
+        if decision is not None:
+            running_true_count.add(1) if decision == True else running_false_count.add(1)
+            running_accuracy = running_true_count.get() / (running_true_count.get() + running_false_count.get())
+
+        ## Log results
         decision_color = {True: "green", None: "blue", False: "red"}[decision]
         colored_decision = color_text(str(decision), decision_color)
-        logger.info(
-            f"Decision: {colored_decision}\nQuestion: {prompt}\nExpected Answer: {answer}\nAgent Answer: {agent_response}\nExplanation: {explanation}\n"
-        )
+        result_to_print = f"""
+        ---------
+        Decision: {colored_decision}
+        Accuracy: {running_accuracy:.2%}
+        Question: {prompt}
+        Expected Answer: {answer}
+        Agent Answer: {agent_response}
+        Explanation: {explanation}
+        Cost: ${running_cost.get():.5f} (Query: ${query_cost:.5f}, Eval: ${eval_cost:.5f})
+        ---------
+        """
+        logger.info(dedent(result_to_print).lstrip())

-        time.sleep(SLEEP_SECONDS)  # Rate limiting
+        # Sleep between API calls to avoid rate limiting
+        time.sleep(SLEEP_SECONDS)


 def color_text(text, color):
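Note: the running accuracy printed with each result only counts examples that produced a definite decision, true_count / (true_count + false_count), skipping evaluations that failed (decision is None), while the running cost adds both sides of the pipeline for every example: the Khoj query cost plus the Gemini evaluation cost. For instance, after 7 True and 3 False decisions the log reports an accuracy of 7 / (7 + 3) = 70.00%, regardless of how many evaluations errored out.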
@@ -281,17 +341,18 @@ def main():
         lambda x: (x == True).mean()
     )

-    # Print summary
+    # Collect summary
     colored_accuracy = color_text(f"{accuracy:.2%}", "blue")
-    logger.info(f"\nOverall Accuracy: {colored_accuracy}")
-    logger.info(f"\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}")
-
-    # Save summary to file
+    colored_accuracy_str = f"Overall Accuracy: {colored_accuracy} on {args.dataset.title()} dataset."
+    accuracy_str = f"Overall Accuracy: {accuracy:.2%} on {args.dataset}."
+    accuracy_by_reasoning = f"Accuracy by Reasoning Type:\n{reasoning_type_accuracy}"
+    cost = f"Total Cost: ${running_cost.get():.5f}."
     sample_type = f"Sampling Type: {SAMPLE_SIZE} samples." if SAMPLE_SIZE else "Whole dataset."
     sample_type += " Randomized." if RANDOMIZE else ""
-    summary = (
-        f"Overall Accuracy: {accuracy:.2%}\n\nAccuracy by Reasoning Type:\n{reasoning_type_accuracy}\n\n{sample_type}\n"
-    )
+    logger.info(f"\n{colored_accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n")
+
+    # Save summary to file
+    summary = f"{accuracy_str}\n\n{accuracy_by_reasoning}\n\n{cost}\n\n{sample_type}\n"
     summary_file = args.output.replace(".csv", ".txt") if args.output else None
     summary_file = (
         summary_file or f"{args.dataset}_evaluation_summary_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"