mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Track, return cost and usage metrics in chat api response
- Track input, output token usage and cost for interactions via chat api with openai, anthropic and google chat models - Get usage metadata from OpenAI using stream_options - Handle openai proxies that do not support passing usage in response - Add new usage, end response events returned by chat api. - This can be optionally consumed by clients at a later point - Update streaming clients to mark message as completed after new end response event, not after end llm response event - Ensure usage data from final response generation step is included - Pass usage data after llm response complete. This allows gathering token usage and cost for the final response generation step across streaming and non-streaming modes
This commit is contained in:
parent
7bdc9590dd
commit
c53c3db96b
10 changed files with 139 additions and 38 deletions
|
@ -945,7 +945,7 @@ export class KhojChatView extends KhojPaneView {
|
||||||
console.log("Started streaming", new Date());
|
console.log("Started streaming", new Date());
|
||||||
} else if (chunk.type === 'end_llm_response') {
|
} else if (chunk.type === 'end_llm_response') {
|
||||||
console.log("Stopped streaming", new Date());
|
console.log("Stopped streaming", new Date());
|
||||||
|
} else if (chunk.type === 'end_response') {
|
||||||
// Automatically respond with voice if the subscribed user has sent voice message
|
// Automatically respond with voice if the subscribed user has sent voice message
|
||||||
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
|
if (this.chatMessageState.isVoice && this.setting.userInfo?.is_active)
|
||||||
this.textToSpeech(this.chatMessageState.rawResponse);
|
this.textToSpeech(this.chatMessageState.rawResponse);
|
||||||
|
|
|
@ -133,7 +133,7 @@ export function processMessageChunk(
|
||||||
console.log(`Started streaming: ${new Date()}`);
|
console.log(`Started streaming: ${new Date()}`);
|
||||||
} else if (chunk.type === "end_llm_response") {
|
} else if (chunk.type === "end_llm_response") {
|
||||||
console.log(`Completed streaming: ${new Date()}`);
|
console.log(`Completed streaming: ${new Date()}`);
|
||||||
|
} else if (chunk.type === "end_response") {
|
||||||
// Append any references after all the data has been streamed
|
// Append any references after all the data has been streamed
|
||||||
if (codeContext) currentMessage.codeContext = codeContext;
|
if (codeContext) currentMessage.codeContext = codeContext;
|
||||||
if (onlineContext) currentMessage.onlineContext = onlineContext;
|
if (onlineContext) currentMessage.onlineContext = onlineContext;
|
||||||
|
|
|
@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ def anthropic_completion_with_backoff(
|
||||||
aggregated_response = "{" if response_type == "json_object" else ""
|
aggregated_response = "{" if response_type == "json_object" else ""
|
||||||
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||||
|
|
||||||
|
final_message = None
|
||||||
model_kwargs = model_kwargs or dict()
|
model_kwargs = model_kwargs or dict()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
model_kwargs["system"] = system_prompt
|
model_kwargs["system"] = system_prompt
|
||||||
|
@ -73,6 +74,12 @@ def anthropic_completion_with_backoff(
|
||||||
) as stream:
|
) as stream:
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
|
final_message = stream.get_final_message()
|
||||||
|
|
||||||
|
# Calculate cost of chat
|
||||||
|
input_tokens = final_message.usage.input_tokens
|
||||||
|
output_tokens = final_message.usage.output_tokens
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
|
@ -126,6 +133,7 @@ def anthropic_llm_thread(
|
||||||
]
|
]
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
|
final_message = None
|
||||||
with client.messages.stream(
|
with client.messages.stream(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
|
@ -138,6 +146,12 @@ def anthropic_llm_thread(
|
||||||
for text in stream.text_stream:
|
for text in stream.text_stream:
|
||||||
aggregated_response += text
|
aggregated_response += text
|
||||||
g.send(text)
|
g.send(text)
|
||||||
|
final_message = stream.get_final_message()
|
||||||
|
|
||||||
|
# Calculate cost of chat
|
||||||
|
input_tokens = final_message.usage.input_tokens
|
||||||
|
output_tokens = final_message.usage.output_tokens
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
|
|
|
@ -25,7 +25,7 @@ from khoj.processor.conversation.utils import (
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
|
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ def gemini_completion_with_backoff(
|
||||||
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
except StopCandidateException as e:
|
except StopCandidateException as e:
|
||||||
|
response = None
|
||||||
response_text, _ = handle_gemini_response(e.args)
|
response_text, _ = handle_gemini_response(e.args)
|
||||||
# Respond with reason for stopping
|
# Respond with reason for stopping
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -75,6 +76,11 @@ def gemini_completion_with_backoff(
|
||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Aggregate cost of chat
|
||||||
|
input_tokens = response.usage_metadata.prompt_token_count if response else 0
|
||||||
|
output_tokens = response.usage_metadata.candidates_token_count if response else 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
|
@ -146,6 +152,11 @@ def gemini_llm_thread(
|
||||||
if stopped:
|
if stopped:
|
||||||
raise StopCandidateException(message)
|
raise StopCandidateException(message)
|
||||||
|
|
||||||
|
# Calculate cost of chat
|
||||||
|
input_tokens = chunk.usage_metadata.prompt_token_count
|
||||||
|
output_tokens = chunk.usage_metadata.candidates_token_count
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
|
|
|
@ -4,6 +4,8 @@ from threading import Thread
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
|
@ -18,7 +20,7 @@ from khoj.processor.conversation.utils import (
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
)
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import in_debug_mode
|
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -64,27 +66,34 @@ def completion_with_backoff(
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
chat = client.chat.completions.create(
|
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||||
stream=stream,
|
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model, # type: ignore
|
model=model, # type: ignore
|
||||||
|
stream=stream,
|
||||||
|
stream_options={"include_usage": True} if stream else {},
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not stream:
|
|
||||||
return chat.choices[0].message.content
|
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
for chunk in chat:
|
if not stream:
|
||||||
if len(chunk.choices) == 0:
|
chunk = chat
|
||||||
continue
|
aggregated_response = chunk.choices[0].message.content
|
||||||
delta_chunk = chunk.choices[0].delta # type: ignore
|
else:
|
||||||
if isinstance(delta_chunk, str):
|
for chunk in chat:
|
||||||
aggregated_response += delta_chunk
|
if len(chunk.choices) == 0:
|
||||||
elif delta_chunk.content:
|
continue
|
||||||
aggregated_response += delta_chunk.content
|
delta_chunk = chunk.choices[0].delta # type: ignore
|
||||||
|
if isinstance(delta_chunk, str):
|
||||||
|
aggregated_response += delta_chunk
|
||||||
|
elif delta_chunk.content:
|
||||||
|
aggregated_response += delta_chunk.content
|
||||||
|
|
||||||
|
# Calculate cost of chat
|
||||||
|
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
|
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model
|
tracer["chat_model"] = model
|
||||||
|
@ -164,10 +173,11 @@ def llm_thread(
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
chat = client.chat.completions.create(
|
chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
|
||||||
stream=stream,
|
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
model=model_name, # type: ignore
|
model=model_name, # type: ignore
|
||||||
|
stream=stream,
|
||||||
|
stream_options={"include_usage": True} if stream else {},
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=20,
|
timeout=20,
|
||||||
**(model_kwargs or dict()),
|
**(model_kwargs or dict()),
|
||||||
|
@ -175,7 +185,8 @@ def llm_thread(
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
if not stream:
|
if not stream:
|
||||||
aggregated_response = chat.choices[0].message.content
|
chunk = chat
|
||||||
|
aggregated_response = chunk.choices[0].message.content
|
||||||
g.send(aggregated_response)
|
g.send(aggregated_response)
|
||||||
else:
|
else:
|
||||||
for chunk in chat:
|
for chunk in chat:
|
||||||
|
@ -191,6 +202,11 @@ def llm_thread(
|
||||||
aggregated_response += text_chunk
|
aggregated_response += text_chunk
|
||||||
g.send(text_chunk)
|
g.send(text_chunk)
|
||||||
|
|
||||||
|
# Calculate cost of chat
|
||||||
|
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
|
output_tokens = chunk.usage.completion_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(model_name, input_tokens, output_tokens, tracer.get("usage"))
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
|
|
|
@ -5,7 +5,6 @@ import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -57,7 +56,7 @@ model_to_prompt_size = {
|
||||||
"gemini-1.5-flash": 20000,
|
"gemini-1.5-flash": 20000,
|
||||||
"gemini-1.5-pro": 20000,
|
"gemini-1.5-pro": 20000,
|
||||||
# Anthropic Models
|
# Anthropic Models
|
||||||
"claude-3-5-sonnet-20240620": 20000,
|
"claude-3-5-sonnet-20241022": 20000,
|
||||||
"claude-3-5-haiku-20241022": 20000,
|
"claude-3-5-haiku-20241022": 20000,
|
||||||
# Offline Models
|
# Offline Models
|
||||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
|
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
|
||||||
|
@ -213,6 +212,8 @@ class ChatEvent(Enum):
|
||||||
REFERENCES = "references"
|
REFERENCES = "references"
|
||||||
STATUS = "status"
|
STATUS = "status"
|
||||||
METADATA = "metadata"
|
METADATA = "metadata"
|
||||||
|
USAGE = "usage"
|
||||||
|
END_RESPONSE = "end_response"
|
||||||
|
|
||||||
|
|
||||||
def message_to_log(
|
def message_to_log(
|
||||||
|
|
|
@ -667,27 +667,37 @@ async def chat(
|
||||||
finally:
|
finally:
|
||||||
yield event_delimiter
|
yield event_delimiter
|
||||||
|
|
||||||
async def send_llm_response(response: str):
|
async def send_llm_response(response: str, usage: dict = None):
|
||||||
|
# Send Chat Response
|
||||||
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.START_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
async for result in send_event(ChatEvent.MESSAGE, response):
|
async for result in send_event(ChatEvent.MESSAGE, response):
|
||||||
yield result
|
yield result
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
|
# Send Usage Metadata once llm interactions are complete
|
||||||
|
if usage:
|
||||||
|
async for event in send_event(ChatEvent.USAGE, usage):
|
||||||
|
yield event
|
||||||
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
|
|
||||||
def collect_telemetry():
|
def collect_telemetry():
|
||||||
# Gather chat response telemetry
|
# Gather chat response telemetry
|
||||||
nonlocal chat_metadata
|
nonlocal chat_metadata
|
||||||
latency = time.perf_counter() - start_time
|
latency = time.perf_counter() - start_time
|
||||||
cmd_set = set([cmd.value for cmd in conversation_commands])
|
cmd_set = set([cmd.value for cmd in conversation_commands])
|
||||||
|
cost = (tracer.get("usage", {}) or {}).get("cost", 0)
|
||||||
chat_metadata = chat_metadata or {}
|
chat_metadata = chat_metadata or {}
|
||||||
chat_metadata["conversation_command"] = cmd_set
|
chat_metadata["conversation_command"] = cmd_set
|
||||||
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
chat_metadata["agent"] = conversation.agent.slug if conversation.agent else None
|
||||||
chat_metadata["latency"] = f"{latency:.3f}"
|
chat_metadata["latency"] = f"{latency:.3f}"
|
||||||
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
|
chat_metadata["ttft_latency"] = f"{ttft:.3f}"
|
||||||
|
chat_metadata["usage"] = tracer.get("usage")
|
||||||
|
|
||||||
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
|
logger.info(f"Chat response time to first token: {ttft:.3f} seconds")
|
||||||
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
logger.info(f"Chat response total time: {latency:.3f} seconds")
|
||||||
|
logger.info(f"Chat response cost: ${cost:.5f}")
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
|
@ -699,7 +709,7 @@ async def chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_query_empty(q):
|
if is_query_empty(q):
|
||||||
async for result in send_llm_response("Please ask your query to get started."):
|
async for result in send_llm_response("Please ask your query to get started.", tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -713,7 +723,7 @@ async def chat(
|
||||||
create_new=body.create_new,
|
create_new=body.create_new,
|
||||||
)
|
)
|
||||||
if not conversation:
|
if not conversation:
|
||||||
async for result in send_llm_response(f"Conversation {conversation_id} not found"):
|
async for result in send_llm_response(f"Conversation {conversation_id} not found", tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
conversation_id = conversation.id
|
conversation_id = conversation.id
|
||||||
|
@ -777,7 +787,7 @@ async def chat(
|
||||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||||
q = q.replace(f"/{cmd.value}", "").strip()
|
q = q.replace(f"/{cmd.value}", "").strip()
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
async for result in send_llm_response(str(e.detail)):
|
async for result in send_llm_response(str(e.detail), tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -834,7 +844,7 @@ async def chat(
|
||||||
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
|
||||||
if len(file_filters) == 0 and not agent_has_entries:
|
if len(file_filters) == 0 and not agent_has_entries:
|
||||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
else:
|
else:
|
||||||
async for response in generate_summary_from_files(
|
async for response in generate_summary_from_files(
|
||||||
|
@ -853,7 +863,7 @@ async def chat(
|
||||||
else:
|
else:
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
response_log = response
|
response_log = response
|
||||||
async for result in send_llm_response(response):
|
async for result in send_llm_response(response, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
@ -880,7 +890,7 @@ async def chat(
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
model_type = conversation_config.model_type
|
model_type = conversation_config.model_type
|
||||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||||
async for result in send_llm_response(formatted_help):
|
async for result in send_llm_response(formatted_help, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
# Adding specification to search online specifically on khoj.dev pages.
|
# Adding specification to search online specifically on khoj.dev pages.
|
||||||
|
@ -895,7 +905,7 @@ async def chat(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||||
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
error_message = f"Unable to create automation. Ensure the automation doesn't already exist."
|
||||||
async for result in send_llm_response(error_message):
|
async for result in send_llm_response(error_message, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -916,7 +926,7 @@ async def chat(
|
||||||
raw_query_files=raw_query_files,
|
raw_query_files=raw_query_files,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -963,7 +973,7 @@ async def chat(
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
async for result in send_llm_response(f"{no_entries_found.format()}", tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1105,7 +1115,7 @@ async def chat(
|
||||||
"detail": improved_image_prompt,
|
"detail": improved_image_prompt,
|
||||||
"image": None,
|
"image": None,
|
||||||
}
|
}
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1132,7 +1142,7 @@ async def chat(
|
||||||
"inferredQueries": [improved_image_prompt],
|
"inferredQueries": [improved_image_prompt],
|
||||||
"image": generated_image,
|
"image": generated_image,
|
||||||
}
|
}
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1166,7 +1176,7 @@ async def chat(
|
||||||
diagram_description = excalidraw_diagram_description
|
diagram_description = excalidraw_diagram_description
|
||||||
else:
|
else:
|
||||||
error_message = "Failed to generate diagram. Please try again later."
|
error_message = "Failed to generate diagram. Please try again later."
|
||||||
async for result in send_llm_response(error_message):
|
async for result in send_llm_response(error_message, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
@ -1213,7 +1223,7 @@ async def chat(
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -1252,6 +1262,11 @@ async def chat(
|
||||||
if item is None:
|
if item is None:
|
||||||
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
async for result in send_event(ChatEvent.END_LLM_RESPONSE, ""):
|
||||||
yield result
|
yield result
|
||||||
|
# Send Usage Metadata once llm interactions are complete
|
||||||
|
async for event in send_event(ChatEvent.USAGE, tracer.get("usage")):
|
||||||
|
yield event
|
||||||
|
async for result in send_event(ChatEvent.END_RESPONSE, ""):
|
||||||
|
yield result
|
||||||
logger.debug("Finished streaming response")
|
logger.debug("Finished streaming response")
|
||||||
return
|
return
|
||||||
if not connection_alive or not continue_stream:
|
if not connection_alive or not continue_stream:
|
||||||
|
|
|
@ -1770,6 +1770,7 @@ Manage your automations [here](/automations).
|
||||||
class MessageProcessor:
|
class MessageProcessor:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.references = {}
|
self.references = {}
|
||||||
|
self.usage = {}
|
||||||
self.raw_response = ""
|
self.raw_response = ""
|
||||||
|
|
||||||
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
||||||
|
@ -1793,6 +1794,8 @@ class MessageProcessor:
|
||||||
chunk_type = ChatEvent(chunk["type"])
|
chunk_type = ChatEvent(chunk["type"])
|
||||||
if chunk_type == ChatEvent.REFERENCES:
|
if chunk_type == ChatEvent.REFERENCES:
|
||||||
self.references = chunk["data"]
|
self.references = chunk["data"]
|
||||||
|
elif chunk_type == ChatEvent.USAGE:
|
||||||
|
self.usage = chunk["data"]
|
||||||
elif chunk_type == ChatEvent.MESSAGE:
|
elif chunk_type == ChatEvent.MESSAGE:
|
||||||
chunk_data = chunk["data"]
|
chunk_data = chunk["data"]
|
||||||
if isinstance(chunk_data, dict):
|
if isinstance(chunk_data, dict):
|
||||||
|
@ -1837,7 +1840,7 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||||
if buffer:
|
if buffer:
|
||||||
processor.process_message_chunk(buffer)
|
processor.process_message_chunk(buffer)
|
||||||
|
|
||||||
return {"response": processor.raw_response, "references": processor.references}
|
return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
|
||||||
|
|
||||||
|
|
||||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
app_root_directory = Path(__file__).parent.parent.parent
|
app_root_directory = Path(__file__).parent.parent.parent
|
||||||
web_directory = app_root_directory / "khoj/interface/web/"
|
web_directory = app_root_directory / "khoj/interface/web/"
|
||||||
|
@ -31,3 +32,19 @@ default_config = {
|
||||||
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
|
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_to_cost: Dict[str, Dict[str, float]] = {
|
||||||
|
# OpenAI Pricing: https://openai.com/api/pricing/
|
||||||
|
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||||
|
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||||
|
"o1-preview": {"input": 15.0, "output": 60.00},
|
||||||
|
"o1-mini": {"input": 3.0, "output": 12.0},
|
||||||
|
# Gemini Pricing: https://ai.google.dev/pricing
|
||||||
|
"gemini-1.5-flash": {"input": 0.075, "output": 0.30},
|
||||||
|
"gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
|
||||||
|
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
||||||
|
"gemini-1.5-pro-002": {"input": 1.25, "output": 5.00},
|
||||||
|
# Anthropic Pricing: https://www.anthropic.com/pricing#anthropic-api_
|
||||||
|
"claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},
|
||||||
|
"claude-3-5-haiku-20241022": {"input": 1.0, "output": 5.0},
|
||||||
|
}
|
||||||
|
|
|
@ -540,3 +540,27 @@ def get_country_code_from_timezone(tz: str) -> str:
|
||||||
def get_country_name_from_timezone(tz: str) -> str:
|
def get_country_name_from_timezone(tz: str) -> str:
|
||||||
"""Get country name from timezone"""
|
"""Get country name from timezone"""
|
||||||
return country_names.get(get_country_code_from_timezone(tz), "United States")
|
return country_names.get(get_country_code_from_timezone(tz), "United States")
|
||||||
|
|
||||||
|
|
||||||
|
def get_cost_of_chat_message(model_name: str, input_tokens: int = 0, output_tokens: int = 0, prev_cost: float = 0.0):
|
||||||
|
"""
|
||||||
|
Calculate cost of chat message based on input and output tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Calculate cost of input and output tokens. Costs are per million tokens
|
||||||
|
input_cost = constants.model_to_cost.get(model_name, {}).get("input", 0) * (input_tokens / 1e6)
|
||||||
|
output_cost = constants.model_to_cost.get(model_name, {}).get("output", 0) * (output_tokens / 1e6)
|
||||||
|
|
||||||
|
return input_cost + output_cost + prev_cost
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_usage_metrics(model_name: str, input_tokens: int = 0, output_tokens: int = 0, usage: dict = {}):
|
||||||
|
"""
|
||||||
|
Get usage metrics for chat message based on input and output tokens
|
||||||
|
"""
|
||||||
|
prev_usage = usage or {"input_tokens": 0, "output_tokens": 0, "cost": 0.0}
|
||||||
|
return {
|
||||||
|
"input_tokens": prev_usage["input_tokens"] + input_tokens,
|
||||||
|
"output_tokens": prev_usage["output_tokens"] + output_tokens,
|
||||||
|
"cost": get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue