feat(openai): Populate OpenAI service tier token details (#32721)

This commit is contained in:
Jacob Lee
2025-10-16 12:14:57 -07:00
committed by GitHub
parent 13259a109a
commit 6d73003b17
2 changed files with 112 additions and 20 deletions

View File

@@ -917,7 +917,9 @@ class BaseChatOpenAI(BaseChatModel):
)
usage_metadata: UsageMetadata | None = (
_create_usage_metadata(token_usage) if token_usage else None
_create_usage_metadata(token_usage, chunk.get("service_tier"))
if token_usage
else None
)
if len(choices) == 0:
# logprobs is implicitly None
@@ -1289,11 +1291,14 @@ class BaseChatOpenAI(BaseChatModel):
raise TypeError(msg)
token_usage = response_dict.get("usage")
service_tier = response_dict.get("service_tier")
for res in choices:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
message.usage_metadata = _create_usage_metadata(token_usage)
message.usage_metadata = _create_usage_metadata(
token_usage, service_tier
)
generation_info = generation_info or {}
generation_info["finish_reason"] = (
res.get("finish_reason")
@@ -1312,8 +1317,8 @@ class BaseChatOpenAI(BaseChatModel):
}
if "id" in response_dict:
llm_output["id"] = response_dict["id"]
if "service_tier" in response_dict:
llm_output["service_tier"] = response_dict["service_tier"]
if service_tier:
llm_output["service_tier"] = service_tier
if isinstance(response, openai.BaseModel) and getattr(
response, "choices", None
@@ -3387,26 +3392,40 @@ class OpenAIRefusalError(Exception):
"""
def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
def _create_usage_metadata(
oai_token_usage: dict, service_tier: str | None = None
) -> UsageMetadata:
input_tokens = oai_token_usage.get("prompt_tokens") or 0
output_tokens = oai_token_usage.get("completion_tokens") or 0
total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens
if service_tier not in {"priority", "flex"}:
service_tier = None
service_tier_prefix = f"{service_tier}_" if service_tier else ""
input_token_details: dict = {
"audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
"audio_tokens"
),
"cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
"cached_tokens"
),
f"{service_tier_prefix}cache_read": (
oai_token_usage.get("prompt_tokens_details") or {}
).get("cached_tokens"),
}
output_token_details: dict = {
"audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
"audio_tokens"
),
"reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
"reasoning_tokens"
),
f"{service_tier_prefix}reasoning": (
oai_token_usage.get("completion_tokens_details") or {}
).get("reasoning_tokens"),
}
if service_tier is not None:
    # Record the tokens billed at the base service-tier rate under the bare
    # tier key (e.g. "flex"), excluding the tier-prefixed cache-read and
    # reasoning counts, which are reported under their own keys.
    #
    # The prefixed keys are always present in the detail dicts built above,
    # but their value is ``None`` whenever the API response omits the
    # corresponding ``*_tokens_details`` entry. A ``.get(key, 0)`` default
    # therefore never applies, and ``int - None`` would raise ``TypeError``;
    # coerce a missing count to 0 explicitly instead.
    tier_cache_read = (
        input_token_details.get(f"{service_tier_prefix}cache_read") or 0
    )
    tier_reasoning = (
        output_token_details.get(f"{service_tier_prefix}reasoning") or 0
    )
    input_token_details[service_tier] = input_tokens - tier_cache_read
    output_token_details[service_tier] = output_tokens - tier_reasoning
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
@@ -3420,20 +3439,34 @@ def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
)
def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
def _create_usage_metadata_responses(
oai_token_usage: dict, service_tier: str | None = None
) -> UsageMetadata:
input_tokens = oai_token_usage.get("input_tokens", 0)
output_tokens = oai_token_usage.get("output_tokens", 0)
total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
if service_tier not in {"priority", "flex"}:
service_tier = None
service_tier_prefix = f"{service_tier}_" if service_tier else ""
output_token_details: dict = {
"reasoning": (oai_token_usage.get("output_tokens_details") or {}).get(
"reasoning_tokens"
)
f"{service_tier_prefix}reasoning": (
oai_token_usage.get("output_tokens_details") or {}
).get("reasoning_tokens")
}
input_token_details: dict = {
"cache_read": (oai_token_usage.get("input_tokens_details") or {}).get(
"cached_tokens"
)
f"{service_tier_prefix}cache_read": (
oai_token_usage.get("input_tokens_details") or {}
).get("cached_tokens")
}
if service_tier is not None:
    # Record the tokens billed at the base service-tier rate under the bare
    # tier key (e.g. "flex"), excluding the tier-prefixed reasoning and
    # cache-read counts, which are reported under their own keys.
    #
    # The prefixed keys are always present in the detail dicts built above,
    # but their value is ``None`` whenever the usage payload has no
    # ``output_tokens_details`` / ``input_tokens_details`` entry for them. A
    # ``.get(key, 0)`` default therefore never applies, and ``int - None``
    # would raise ``TypeError``; coerce a missing count to 0 explicitly.
    tier_reasoning = (
        output_token_details.get(f"{service_tier_prefix}reasoning") or 0
    )
    tier_cache_read = (
        input_token_details.get(f"{service_tier_prefix}cache_read") or 0
    )
    output_token_details[service_tier] = output_tokens - tier_reasoning
    input_token_details[service_tier] = input_tokens - tier_cache_read
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
@@ -3957,7 +3990,9 @@ def _construct_lc_result_from_responses_api(
response_metadata["model_provider"] = "openai"
response_metadata["model_name"] = response_metadata.get("model")
if response.usage:
usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
usage_metadata = _create_usage_metadata_responses(
response.usage.model_dump(), response.service_tier
)
else:
usage_metadata = None

View File

@@ -194,13 +194,35 @@ def test_openai_invoke() -> None:
result = llm.invoke("Hello", config={"tags": ["foo"]})
assert isinstance(result.content, str)
usage_metadata = result.usage_metadata # type: ignore[attr-defined]
# assert no response headers if include_response_headers is not set
assert "headers" not in result.response_metadata
assert usage_metadata is not None
flex_input = usage_metadata.get("input_token_details", {}).get("flex")
assert isinstance(flex_input, int)
assert flex_input > 0
assert flex_input == usage_metadata.get("input_tokens")
flex_output = usage_metadata.get("output_token_details", {}).get("flex")
assert isinstance(flex_output, int)
assert flex_output > 0
# GPT-5-nano/reasoning model specific. Remove if model used in test changes.
flex_reasoning = usage_metadata.get("output_token_details", {}).get(
"flex_reasoning"
)
assert isinstance(flex_reasoning, int)
assert flex_reasoning > 0
assert flex_reasoning + flex_output == usage_metadata.get("output_tokens")
@pytest.mark.flaky(retries=3, delay=1)
def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(model="gpt-4.1-mini")
llm = ChatOpenAI(
model="gpt-5-nano",
service_tier="flex", # Also test service_tier
max_retries=3, # Add retries for 503 capacity errors
)
full: BaseMessageChunk | None = None
for chunk in llm.stream("I'm Pickle Rick"):
@@ -236,6 +258,19 @@ def test_stream() -> None:
assert aggregate.usage_metadata["input_tokens"] > 0
assert aggregate.usage_metadata["output_tokens"] > 0
assert aggregate.usage_metadata["total_tokens"] > 0
assert aggregate.usage_metadata.get("input_token_details", {}).get("flex", 0) > 0 # type: ignore[operator]
assert aggregate.usage_metadata.get("output_token_details", {}).get("flex", 0) > 0 # type: ignore[operator]
assert (
aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator]
"flex_reasoning", 0
)
> 0
)
assert aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator]
"flex_reasoning", 0
) + aggregate.usage_metadata.get("output_token_details", {}).get(
"flex", 0
) == aggregate.usage_metadata.get("output_tokens")
async def test_astream() -> None:
@@ -308,6 +343,28 @@ async def test_astream() -> None:
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
@pytest.mark.parametrize("streaming", [False, True])
def test_flex_usage_responses(streaming: bool) -> None:
    """Check flex-tier token details when using the Responses API.

    Invokes a ``ChatOpenAI`` model configured with ``service_tier="flex"`` and
    ``use_responses_api=True`` (with and without streaming) and verifies that
    the usage metadata carries integer ``flex`` input/output counts and a
    ``flex_reasoning`` output count, and that the two output counts add up to
    the reported ``output_tokens``.
    """
    model_kwargs = {
        "model": "gpt-5-nano",
        "service_tier": "flex",
        "max_retries": 3,
        "use_responses_api": True,
        "streaming": streaming,
    }
    llm = ChatOpenAI(**model_kwargs)

    result = llm.invoke("Hello")
    assert result.usage_metadata

    # Pull each flex-specific counter out of the usage details up front.
    flex_input = result.usage_metadata.get("input_token_details", {}).get("flex")
    flex_output = result.usage_metadata.get("output_token_details", {}).get("flex")
    flex_reasoning = result.usage_metadata.get("output_token_details", {}).get(
        "flex_reasoning"
    )

    assert isinstance(flex_input, int)
    assert isinstance(flex_output, int)
    assert isinstance(flex_reasoning, int)

    # Plain flex output plus flex reasoning must account for every output token.
    expected_output_tokens = result.usage_metadata.get("output_tokens")
    assert flex_output + flex_reasoning == expected_output_tokens
async def test_abatch_tags() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI()