diff --git a/src/interface/obsidian/src/chat_view.ts b/src/interface/obsidian/src/chat_view.ts
index 408ce3a1..552a54bd 100644
--- a/src/interface/obsidian/src/chat_view.ts
+++ b/src/interface/obsidian/src/chat_view.ts
@@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
             console.log("Started streaming", new Date());
         } else if (chunk.type === 'end_llm_response') {
             console.log("Stopped streaming", new Date());
-
+        } else if (chunk.type === 'end_response') {
             // Automatically respond with voice if the subscribed user has sent voice message
             if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
                 this.textToSpeech(this.chatMessageState.rawResponse);
diff --git a/src/interface/web/app/common/chatFunctions.ts b/src/interface/web/app/common/chatFunctions.ts
index aeb74ee8..6585b4c9 100644
--- a/src/interface/web/app/common/chatFunctions.ts
+++ b/src/interface/web/app/common/chatFunctions.ts
@@ -133,7 +133,7 @@ export function processMessageChunk(
         console.log(`Started streaming: ${new Date()}`);
     } else if (chunk.type === "end_llm_response") {
         console.log(`Completed streaming: ${new Date()}`);
-
+    } else if (chunk.type === "end_response") {
         // Append any references after all the data has been streamed
         if (codeContext) currentMessage.codeContext = codeContext;
         if (onlineContext) currentMessage.onlineContext = onlineContext;
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index cdce63c6..4f6b68c2 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
     aggregated_response = "{" if response_type == "json_object" else ""
     max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC

+    final_message = None
     model_kwargs = model_kwargs or dict()
     if system_prompt:
         model_kwargs["system"] = system_prompt
@@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
     ) as stream:
         for text in stream.text_stream:
             aggregated_response += text
+        final_message = stream.get_final_message()
+
+    # Calculate cost of chat
+    input_tokens = final_message.usage.input_tokens
+    output_tokens = final_message.usage.output_tokens
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model_name
@@ -126,6 +133,7 @@ def anthropic_llm_thread(
         ]

         aggregated_response = ""
+        final_message = None
         with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
@@ -138,6 +146,12 @@ def anthropic_llm_thread(
             for text in stream.text_stream:
                 aggregated_response += text
                 g.send(text)
+            final_message = stream.get_final_message()
+
+        # Calculate cost of chat
+        input_tokens = final_message.usage.input_tokens
+        output_tokens = final_message.usage.output_tokens
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))

         # Save conversation trace
         tracer["chat_model"] = model_name
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 84ad607e..eb9b21b0 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
         response = chat_session.send_message(formatted_messages[-1]["parts"])
         response_text = response.text
     except StopCandidateException as e:
+        response = None
        response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
@@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )

+    # Aggregate cost of chat
+    input_tokens = response.usage_metadata.prompt_token_count if response else 0
+    output_tokens = response.usage_metadata.candidates_token_count if response else 0
+    tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
@@ -146,6 +152,11 @@ def gemini_llm_thread(
             if stopped:
                 raise StopCandidateException(message)

+        # Calculate cost of chat
+        input_tokens = chunk.usage_metadata.prompt_token_count
+        output_tokens = chunk.usage_metadata.candidates_token_count
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 36ebc679..3ffaf753 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -4,6 +4,8 @@ from threading import Thread
 from typing import Dict

 import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from tenacity import (
     before_sleep_log,
     retry,
@@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
     commit_conversation_trace,
 )
 from khoj.utils import state
-from khoj.utils.helpers import in_debug_mode
+from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode

 logger = logging.getLogger(__name__)

@@ -64,27 +66,34 @@ def completion_with_backoff(
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat = client.chat.completions.create(
-        stream=stream,
+    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
+        stream=stream,
+        stream_options={"include_usage": True} if stream else {},
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )

-    if not stream:
-        return chat.choices[0].message.content
-
     aggregated_response = ""
-    for chunk in chat:
-        if len(chunk.choices) == 0:
-            continue
-        delta_chunk = chunk.choices[0].delta  # type: ignore
-        if isinstance(delta_chunk, str):
-            aggregated_response += delta_chunk
-        elif delta_chunk.content:
-            aggregated_response += delta_chunk.content
+    if not stream:
+        chunk = chat
+        aggregated_response = chunk.choices[0].message.content
+    else:
+        for chunk in chat:
+            if len(chunk.choices) == 0:
+                continue
+            delta_chunk = chunk.choices[0].delta  # type: ignore
+            if isinstance(delta_chunk, str):
+                aggregated_response += delta_chunk
+            elif delta_chunk.content:
+                aggregated_response += delta_chunk.content
+
+    # Calculate cost of chat
+    input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+    tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))

     # Save conversation trace
     tracer["chat_model"] = model
@@ -164,10 +173,11 @@ def llm_thread(
         if os.getenv("KHOJ_LLM_SEED"):
             model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-        chat = client.chat.completions.create(
-            stream=stream,
+        chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
             messages=formatted_messages,
             model=model_name,  # type: ignore
+            stream=stream,
+            stream_options={"include_usage": True} if stream else {},
             temperature=temperature,
             timeout=20,
             **(model_kwargs or dict()),
@@ -175,7 +185,8 @@ def llm_thread(

         aggregated_response = ""
         if not stream:
-            aggregated_response = chat.choices[0].message.content
+            chunk = chat
+            aggregated_response = chunk.choices[0].message.content
             g.send(aggregated_response)
         else:
             for chunk in chat:
@@ -191,6 +202,11 @@ def llm_thread(
                     aggregated_response += text_chunk
                     g.send(text_chunk)

+        # Calculate cost of chat
+        input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+        output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
+        tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
+
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 91ba3c72..cd55c986 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -5,7 +5,6 @@ import math
 import mimetypes
 import os
 import queue
-import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
@@ -57,7 +56,7 @@ model_to_prompt_size = {
     "gemini-1.5-flash": 20000,
     "gemini-1.5-pro": 20000,
     # Anthropic Models
-    "claude-3-5-sonnet-20240620": 20000,
+    "claude-3-5-sonnet-20241022": 20000,
     "claude-3-5-haiku-20241022": 20000,
     # Offline Models
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -213,6 +212,8 @@ class ChatEvent(Enum):
     REFERENCES = "references"
     STATUS = "status"
     METADATA = "metadata"
+    USAGE = "usage"
+    END_RESPONSE = "end_response"


 def message_to_log(
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index a9086dd0..4c6fd7c4 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -667,27 +667,37 @@ async def chat(
             finally:
                 yield event_delimiter

-        async def send_llm_response(response: str):
+        async def send_llm_response(response: str, usage: dict = None):
+            # Send Chat Response
             async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
                 yield result
             async for result in send_event(ChatEvent.MESSAGE, response):
                 yield result
             async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                 yield result
+            # Send Usage Metadata once llm interactions are complete
+            if usage:
+                async for event in send_event(ChatEvent.USAGE, usage):
+                    yield event
+            async for result in send_event(ChatEvent.END_RESPONSE, ""):
+                yield result

         def collect_telemetry():
             # Gather chat response telemetry
             nonlocal chat_metadata
             latency = time.perf_counter() - start_time
             cmd_set = set([cmd.value for cmd in conversation_commands])
+            cost = (tracer.get("usage", {}) or {}).get("cost", 0)
             chat_metadata = chat_metadata or {}
             chat_metadata["conversation_command"] = cmd_set
             chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
             chat_metadata["latency"] = f"{latency:.3f}"
             chat_metadata["ttft_latency"] = f"{ttft:.3f}"
+            chat_metadata["usage"] = tracer.get("usage")

             logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
             logger.info(f"Chat response total time: {latency:.3f} seconds")
+            logger.info(f"Chat response cost: ${cost:.5f}")
             update_telemetry_state(
                 request=request,
                 telemetry_type="api",
@@ -699,7 +709,7 @@ async def chat(
         )

         if is_query_empty(q):
-            async for result in send_llm_response("Please ask your query to get started."):
+            async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
                 yield result
             return

@@ -713,7 +723,7 @@ async def chat(
             create_new=body.create_new,
         )
         if not conversation:
-            async for result in send_llm_response(f"Conversation {conversation_id} not found"):
+            async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
                 yield result
             return
         conversation_id = conversation.id
@@ -777,7 +787,7 @@ async def chat(
                 await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
                 q = q.replace(f"/{cmd.value}", "").strip()
             except HTTPException as e:
-                async for result in send_llm_response(str(e.detail)):
+                async for result in send_llm_response(str(e.detail), tracer.get("usage")):
                     yield result
                 return

@@ -834,7 +844,7 @@ async def chat(
             agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
             if len(file_filters) == 0 and not agent_has_entries:
                 response_log = "No files selected for summarization. Please add files using the section on the left."
-                async for result in send_llm_response(response_log):
+                async for result in send_llm_response(response_log, tracer.get("usage")):
                     yield result
             else:
                 async for response in generate_summary_from_files(
@@ -853,7 +863,7 @@ async def chat(
                     else:
                         if isinstance(response, str):
                             response_log = response
-                            async for result in send_llm_response(response):
+                            async for result in send_llm_response(response, tracer.get("usage")):
                                 yield result

             await sync_to_async(save_to_conversation_log)(
@@ -880,7 +890,7 @@ async def chat(
                 conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
             model_type = conversation_config.model_type
             formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
-            async for result in send_llm_response(formatted_help):
+            async for result in send_llm_response(formatted_help, tracer.get("usage")):
                 yield result
             return
         # Adding specification to search online specifically on khoj.dev pages.
@@ -895,7 +905,7 @@ async def chat(
         except Exception as e:
             logger.error(f"Error scheduling task {q} for {user.email}: {e}")
             error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
-            async for result in send_llm_response(error_message):
+            async for result in send_llm_response(error_message, tracer.get("usage")):
                 yield result
             return

@@ -916,7 +926,7 @@ async def chat(
             raw_query_files=raw_query_files,
             tracer=tracer,
         )
-        async for result in send_llm_response(llm_response):
+        async for result in send_llm_response(llm_response, tracer.get("usage")):
             yield result
         return

@@ -963,7 +973,7 @@ async def chat(
                 yield result

         if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
-            async for result in send_llm_response(f"{no_entries_found.format()}"):
+            async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
                 yield result
             return

@@ -1105,7 +1115,7 @@ async def chat(
                     "detail": improved_image_prompt,
                     "image": None,
                 }
-                async for result in send_llm_response(json.dumps(content_obj)):
+                async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                     yield result
                 return

@@ -1132,7 +1142,7 @@ async def chat(
                 "inferredQueries": [improved_image_prompt],
                 "image": generated_image,
             }
-            async for result in send_llm_response(json.dumps(content_obj)):
+            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                 yield result
             return

@@ -1166,7 +1176,7 @@ async def chat(
                 diagram_description = excalidraw_diagram_description
             else:
                 error_message = "Failed to generate diagram. Please try again later."
-                async for result in send_llm_response(error_message):
+                async for result in send_llm_response(error_message, tracer.get("usage")):
                     yield result

                 await sync_to_async(save_to_conversation_log)(
@@ -1213,7 +1223,7 @@ async def chat(
                     tracer=tracer,
                 )

-            async for result in send_llm_response(json.dumps(content_obj)):
+            async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
                 yield result
             return

@@ -1252,6 +1262,11 @@ async def chat(
                 if item is None:
                     async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
                         yield result
+                    # Send Usage Metadata once llm interactions are complete
+                    async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
+                        yield event
+                    async for result in send_event(ChatEvent.END_RESPONSE, ""):
+                        yield result
                     logger.debug("Finished streaming response")
                     return
                 if not connection_alive or not continue_stream:
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index b011106e..4fd30de9 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
 class MessageProcessor:
     def __init__(self):
         self.references = {}
+        self.usage = {}
         self.raw_response = ""

     def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
@@ -1793,6 +1794,8 @@ class MessageProcessor:
            chunk_type = ChatEvent(chunk["type"])
            if chunk_type == ChatEvent.REFERENCES:
                self.references = chunk["data"]
+           elif chunk_type == ChatEvent.USAGE:
+               self.usage = chunk["data"]
            elif chunk_type == ChatEvent.MESSAGE:
                chunk_data = chunk["data"]
                if isinstance(chunk_data, dict):
@@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
     if buffer:
         processor.process_message_chunk(buffer)

-    return {"response": processor.raw_response, "references": processor.references}
+    return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}


 def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index 591a0f08..40320373 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Dict

 app_root_directory = Path(__file__).parent.parent.parent
 web_directory = app_root_directory / "khoj/interface/web/"
@@ -31,3 +32,19 @@ default_config = {
         "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
     },
 }
+
+model_to_cost: Dict[str, Dict[str, float]] = {
+    # OpenAI Pricing: https://openai.com/api/pricing/
+    "gpt-4o": {"input": 2.50, "output": 10.00},
+    "gpt-4o-mini": {"input": 0.15, "output": 0.60},
+    "o1-preview": {"input": 15.0, "output": 60.00},
+    "o1-mini": {"input": 3.0, "output": 12.0},
+    # Gemini Pricing: https://ai.google.dev/pricing
+    "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
+    "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
+    "gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
+    # Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
+    "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
+    "claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
+}
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index d1617d79..02cd7a92 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
 def get_country_name_from_timezone(tz: str) -> str:
     """Get country name from timezone"""
     return country_names.get(get_country_code_from_timezone(tz), "United States")
+
+
+def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
+    """
+    Calculate cost of chat message based on input and output tokens
+    """
+
+    # Calculate cost of input and output tokens. Costs are per million tokens
+    input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
+    output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
+
+    return input_cost + output_cost + prev_cost
+
+
+def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
+    """
+    Get usage metrics for chat message based on input and output tokens
+    """
+    prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
+    return {
+        "input_tokens": prev_usage["input_tokens"] + input_tokens,
+        "output_tokens": prev_usage["output_tokens"] + output_tokens,
+        "cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
+    }
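
The usage helpers above are meant to be chained: each provider call passes the previous `tracer.get("usage")` back in, so token counts and cost accumulate across all LLM calls made while serving a single chat request. A small illustrative sketch, assuming this patch is applied so the helpers are importable from `khoj.utils.helpers`; the token counts are made up:

from khoj.utils.helpers import get_chat_usage_metrics

# Two hypothetical gpt-4o calls within one chat turn
usage = None
usage = get_chat_usage_metrics("gpt-4o", input_tokens=1200, output_tokens=350, usage=usage)
usage = get_chat_usage_metrics("gpt-4o", input_tokens=900, output_tokens=200, usage=usage)

# Tokens sum across calls; cost threads through via prev_cost:
# 2.50 * 2100/1e6 (input) + 10.00 * 550/1e6 (output) ≈ $0.01075
print(usage)  # {'input_tokens': 2100, 'output_tokens': 550, 'cost': 0.01075...}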
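
The OpenAI changes depend on `stream_options={"include_usage": True}`: when set, the streaming API emits one final chunk whose `choices` list is empty and whose `usage` field carries the prompt and completion token counts, which is why the cost block reads `chunk.usage` after the loop. A standalone sketch of that behavior, separate from the Khoj code above; the model, prompt, and an `OPENAI_API_KEY` in the environment are assumptions:

import openai

client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

last_chunk = None
for chunk in stream:
    # Content chunks have choices; the terminal chunk has no choices, only usage
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    last_chunk = chunk

print()
if last_chunk and last_chunk.usage:
    print(last_chunk.usage.prompt_tokens, last_chunk.usage.completion_tokens)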
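
The Anthropic integration gets the same numbers from `stream.get_final_message()` once the text stream is drained; the final message's `usage` holds `input_tokens` and `output_tokens`. A minimal standalone sketch, not Khoj code; the model, prompt, and an `ANTHROPIC_API_KEY` in the environment are assumptions:

import anthropic

client = anthropic.Anthropic()
with client.messages.stream(
    model="claude-3-5-haiku-20241022",
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello"}],
) as stream:
    for text in stream.text_stream:
        print(text, end="")
    final_message = stream.get_final_message()  # includes token usage totals

print()
print(final_message.usage.input_tokens, final_message.usage.output_tokens)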