community: Add OpenAI prompt caching and reasoning tokens tracking (#27135)

Adds token tracking for OpenAI's prompt caching and reasoning tokens.
Costs are updated from https://openai.com/api/pricing/.

Usage example:
```python
from langchain_community.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model_name="o1-mini", temperature=1)

with get_openai_callback() as cb:
    response = llm.invoke("hi " * 1500)
    print(cb)
```
Output:
```
Tokens Used: 1720
	Prompt Tokens: 1508
		Prompt Tokens Cached: 1408
	Completion Tokens: 212
		Reasoning Tokens: 192
Successful Requests: 1
Total Cost (USD): $0.0049559999999999995
```
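
For reference, the reported total is consistent with the new `o1-mini` rates (per 1K tokens: $0.003 uncached prompt, $0.0015 cached prompt, $0.012 completion; reasoning tokens are billed as part of the completion tokens):

```python
# Cost breakdown for the run above, using the o1-mini per-1K rates from this PR.
uncached_prompt_cost = (1508 - 1408) * 0.003 / 1000  # ~$0.000300
cached_prompt_cost = 1408 * 0.0015 / 1000            # ~$0.002112
completion_cost = 212 * 0.012 / 1000                 # ~$0.002544 (192 reasoning tokens included)
total = uncached_prompt_cost + cached_prompt_cost + completion_cost
print(total)  # ~0.004956; float accumulation order explains the long tail above
```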

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Commit 4c9acdfbf1 (parent 97f1e1d39f), authored by Vignesh A on 2024-12-19 20:01:13 +05:30 and committed by GitHub.
2 changed files with 123 additions and 10 deletions


Changes to `langchain_community/callbacks/openai_info.py`:

```diff
@@ -1,8 +1,10 @@
 """Callback Handler that prints to std out."""
 
 import threading
+from enum import Enum, auto
 from typing import Any, Dict, List
 
+from langchain_core._api import warn_deprecated
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages import AIMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
@@ -10,26 +12,34 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 MODEL_COST_PER_1K_TOKENS = {
     # OpenAI o1-preview input
     "o1-preview": 0.015,
+    "o1-preview-cached": 0.0075,
     "o1-preview-2024-09-12": 0.015,
+    "o1-preview-2024-09-12-cached": 0.0075,
     # OpenAI o1-preview output
     "o1-preview-completion": 0.06,
     "o1-preview-2024-09-12-completion": 0.06,
     # OpenAI o1-mini input
     "o1-mini": 0.003,
+    "o1-mini-cached": 0.0015,
     "o1-mini-2024-09-12": 0.003,
+    "o1-mini-2024-09-12-cached": 0.0015,
     # OpenAI o1-mini output
     "o1-mini-completion": 0.012,
     "o1-mini-2024-09-12-completion": 0.012,
     # GPT-4o-mini input
     "gpt-4o-mini": 0.00015,
+    "gpt-4o-mini-cached": 0.000075,
     "gpt-4o-mini-2024-07-18": 0.00015,
+    "gpt-4o-mini-2024-07-18-cached": 0.000075,
     # GPT-4o-mini output
     "gpt-4o-mini-completion": 0.0006,
     "gpt-4o-mini-2024-07-18-completion": 0.0006,
     # GPT-4o input
     "gpt-4o": 0.0025,
+    "gpt-4o-cached": 0.00125,
     "gpt-4o-2024-05-13": 0.005,
     "gpt-4o-2024-08-06": 0.0025,
+    "gpt-4o-2024-08-06-cached": 0.00125,
     "gpt-4o-2024-11-20": 0.0025,
     # GPT-4o output
     "gpt-4o-completion": 0.01,
@@ -140,9 +150,19 @@ MODEL_COST_PER_1K_TOKENS = {
 }
 
 
+class TokenType(Enum):
+    """Token type enum."""
+
+    PROMPT = auto()
+    PROMPT_CACHED = auto()
+    COMPLETION = auto()
+
+
 def standardize_model_name(
     model_name: str,
     is_completion: bool = False,
+    *,
+    token_type: TokenType = TokenType.PROMPT,
 ) -> str:
     """
     Standardize the model name to a format that can be used in the OpenAI API.
@@ -150,12 +170,24 @@ def standardize_model_name(
 
     Args:
         model_name: Model name to standardize.
         is_completion: Whether the model is used for completion or not.
-            Defaults to False.
+            Defaults to False. Deprecated in favor of ``token_type``.
+        token_type: Token type. Defaults to ``TokenType.PROMPT``.
 
     Returns:
         Standardized model name.
     """
+    if is_completion:
+        warn_deprecated(
+            since="0.3.13",
+            message=(
+                "is_completion is deprecated. Use token_type instead. Example:\n\n"
+                "from langchain_community.callbacks.openai_info import TokenType\n\n"
+                "standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n"
+            ),
+            removal="1.0",
+        )
+        token_type = TokenType.COMPLETION
     model_name = model_name.lower()
     if ".ft-" in model_name:
         model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
@@ -163,7 +195,7 @@ def standardize_model_name(
         model_name = model_name.split(":")[0] + "-finetuned-legacy"
     if "ft:" in model_name:
         model_name = model_name.split(":")[1] + "-finetuned"
-    if is_completion and (
+    if token_type == TokenType.COMPLETION and (
         model_name.startswith("gpt-4")
         or model_name.startswith("gpt-3.5")
         or model_name.startswith("gpt-35")
@@ -171,12 +203,20 @@ def standardize_model_name(
         or ("finetuned" in model_name and "legacy" not in model_name)
     ):
         return model_name + "-completion"
+    if token_type == TokenType.PROMPT_CACHED and (
+        model_name.startswith("gpt-4o") or model_name.startswith("o1")
+    ):
+        return model_name + "-cached"
     else:
         return model_name
 
 
 def get_openai_token_cost_for_model(
-    model_name: str, num_tokens: int, is_completion: bool = False
+    model_name: str,
+    num_tokens: int,
+    is_completion: bool = False,
+    *,
+    token_type: TokenType = TokenType.PROMPT,
 ) -> float:
     """
     Get the cost in USD for a given model and number of tokens.
@@ -185,12 +225,24 @@ def get_openai_token_cost_for_model(
     Args:
         model_name: Name of the model
         num_tokens: Number of tokens.
         is_completion: Whether the model is used for completion or not.
-            Defaults to False.
+            Defaults to False. Deprecated in favor of ``token_type``.
+        token_type: Token type. Defaults to ``TokenType.PROMPT``.
 
     Returns:
         Cost in USD.
     """
-    model_name = standardize_model_name(model_name, is_completion=is_completion)
+    if is_completion:
+        warn_deprecated(
+            since="0.3.13",
+            message=(
+                "is_completion is deprecated. Use token_type instead. Example:\n\n"
+                "from langchain_community.callbacks.openai_info import TokenType\n\n"
+                "get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n"  # noqa: E501
+            ),
+            removal="1.0",
+        )
+        token_type = TokenType.COMPLETION
+    model_name = standardize_model_name(model_name, token_type=token_type)
     if model_name not in MODEL_COST_PER_1K_TOKENS:
         raise ValueError(
             f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
@@ -204,7 +256,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
 
     total_tokens: int = 0
     prompt_tokens: int = 0
+    prompt_tokens_cached: int = 0
     completion_tokens: int = 0
+    reasoning_tokens: int = 0
     successful_requests: int = 0
     total_cost: float = 0.0
@@ -216,7 +270,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         return (
             f"Tokens Used: {self.total_tokens}\n"
             f"\tPrompt Tokens: {self.prompt_tokens}\n"
+            f"\t\tPrompt Tokens Cached: {self.prompt_tokens_cached}\n"
             f"\tCompletion Tokens: {self.completion_tokens}\n"
+            f"\t\tReasoning Tokens: {self.reasoning_tokens}\n"
             f"Successful Requests: {self.successful_requests}\n"
             f"Total Cost (USD): ${self.total_cost}"
         )
@@ -258,6 +314,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         else:
             usage_metadata = None
             response_metadata = None
+
+        prompt_tokens_cached = 0
+        reasoning_tokens = 0
+
         if usage_metadata:
             token_usage = {"total_tokens": usage_metadata["total_tokens"]}
             completion_tokens = usage_metadata["output_tokens"]
@@ -270,7 +330,12 @@ class OpenAICallbackHandler(BaseCallbackHandler):
                 model_name = standardize_model_name(
                     response.llm_output.get("model_name", "")
                 )
+            if "cache_read" in usage_metadata.get("input_token_details", {}):
+                prompt_tokens_cached = usage_metadata["input_token_details"][
+                    "cache_read"
+                ]
+            if "reasoning" in usage_metadata.get("output_token_details", {}):
+                reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
         else:
             if response.llm_output is None:
                 return None
@@ -287,11 +352,19 @@ class OpenAICallbackHandler(BaseCallbackHandler):
             model_name = standardize_model_name(
                 response.llm_output.get("model_name", "")
            )
         if model_name in MODEL_COST_PER_1K_TOKENS:
-            completion_cost = get_openai_token_cost_for_model(
-                model_name, completion_tokens, is_completion=True
+            uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached
+            uncached_prompt_cost = get_openai_token_cost_for_model(
+                model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT
+            )
+            cached_prompt_cost = get_openai_token_cost_for_model(
+                model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
+            )
+            prompt_cost = uncached_prompt_cost + cached_prompt_cost
+            completion_cost = get_openai_token_cost_for_model(
+                model_name, completion_tokens, token_type=TokenType.COMPLETION
             )
-            prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
         else:
             completion_cost = 0
             prompt_cost = 0
@@ -301,7 +374,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):
             self.total_cost += prompt_cost + completion_cost
             self.total_tokens += token_usage.get("total_tokens", 0)
             self.prompt_tokens += prompt_tokens
+            self.prompt_tokens_cached += prompt_tokens_cached
             self.completion_tokens += completion_tokens
+            self.reasoning_tokens += reasoning_tokens
             self.successful_requests += 1
 
     def __copy__(self) -> "OpenAICallbackHandler":
```
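
A minimal usage sketch of the new `token_type` argument (import path as given in the deprecation messages above; `gpt-4o` rates from the updated cost table):

```python
from langchain_community.callbacks.openai_info import (
    TokenType,
    get_openai_token_cost_for_model,
)

# Per the updated MODEL_COST_PER_1K_TOKENS table, gpt-4o costs $0.0025/1K for
# uncached prompt tokens, half that when cached, and $0.01/1K for completions.
print(get_openai_token_cost_for_model("gpt-4o", 1000, token_type=TokenType.PROMPT))  # 0.0025
print(get_openai_token_cost_for_model("gpt-4o", 1000, token_type=TokenType.PROMPT_CACHED))  # 0.00125
print(get_openai_token_cost_for_model("gpt-4o", 1000, token_type=TokenType.COMPLETION))  # 0.01
```

Calling either helper with the legacy `is_completion=True` still works, but it emits a deprecation warning, behaves as `token_type=TokenType.COMPLETION`, and is slated for removal in 1.0.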


And the accompanying unit tests for `OpenAICallbackHandler`:

```diff
@@ -3,7 +3,8 @@ from uuid import uuid4
 
 import numpy as np
 import pytest
-from langchain_core.outputs import LLMResult
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.utils.pydantic import get_fields
 
 from langchain_community.callbacks import OpenAICallbackHandler
@@ -35,6 +36,43 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None:
     assert handler.total_cost > 0
 
 
+def test_on_llm_end_with_chat_generation(handler: OpenAICallbackHandler) -> None:
+    response = LLMResult(
+        generations=[
+            [
+                ChatGeneration(
+                    text="Hello, world!",
+                    message=AIMessage(
+                        content="Hello, world!",
+                        usage_metadata={
+                            "input_tokens": 2,
+                            "output_tokens": 2,
+                            "total_tokens": 4,
+                            "input_token_details": {
+                                "cache_read": 1,
+                            },
+                            "output_token_details": {
+                                "reasoning": 1,
+                            },
+                        },
+                    ),
+                )
+            ]
+        ],
+        llm_output={
+            "model_name": get_fields(BaseOpenAI)["model_name"].default,
+        },
+    )
+    handler.on_llm_end(response)
+    assert handler.successful_requests == 1
+    assert handler.total_tokens == 4
+    assert handler.prompt_tokens == 2
+    assert handler.prompt_tokens_cached == 1
+    assert handler.completion_tokens == 2
+    assert handler.reasoning_tokens == 1
+    assert handler.total_cost > 0
+
+
 def test_on_llm_end_custom_model(handler: OpenAICallbackHandler) -> None:
     response = LLMResult(
         generations=[],
```