mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(openai): Populate OpenAI service tier token details (#32721)
This commit is contained in:
@@ -917,7 +917,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
|
||||
usage_metadata: UsageMetadata | None = (
|
||||
_create_usage_metadata(token_usage) if token_usage else None
|
||||
_create_usage_metadata(token_usage, chunk.get("service_tier"))
|
||||
if token_usage
|
||||
else None
|
||||
)
|
||||
if len(choices) == 0:
|
||||
# logprobs is implicitly None
|
||||
@@ -1289,11 +1291,14 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
raise TypeError(msg)
|
||||
|
||||
token_usage = response_dict.get("usage")
|
||||
service_tier = response_dict.get("service_tier")
|
||||
|
||||
for res in choices:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if token_usage and isinstance(message, AIMessage):
|
||||
message.usage_metadata = _create_usage_metadata(token_usage)
|
||||
message.usage_metadata = _create_usage_metadata(
|
||||
token_usage, service_tier
|
||||
)
|
||||
generation_info = generation_info or {}
|
||||
generation_info["finish_reason"] = (
|
||||
res.get("finish_reason")
|
||||
@@ -1312,8 +1317,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
}
|
||||
if "id" in response_dict:
|
||||
llm_output["id"] = response_dict["id"]
|
||||
if "service_tier" in response_dict:
|
||||
llm_output["service_tier"] = response_dict["service_tier"]
|
||||
if service_tier:
|
||||
llm_output["service_tier"] = service_tier
|
||||
|
||||
if isinstance(response, openai.BaseModel) and getattr(
|
||||
response, "choices", None
|
||||
@@ -3387,26 +3392,40 @@ class OpenAIRefusalError(Exception):
|
||||
"""
|
||||
|
||||
|
||||
def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
|
||||
def _create_usage_metadata(
|
||||
oai_token_usage: dict, service_tier: str | None = None
|
||||
) -> UsageMetadata:
|
||||
input_tokens = oai_token_usage.get("prompt_tokens") or 0
|
||||
output_tokens = oai_token_usage.get("completion_tokens") or 0
|
||||
total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens
|
||||
if service_tier not in {"priority", "flex"}:
|
||||
service_tier = None
|
||||
service_tier_prefix = f"{service_tier}_" if service_tier else ""
|
||||
input_token_details: dict = {
|
||||
"audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
|
||||
"audio_tokens"
|
||||
),
|
||||
"cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
|
||||
"cached_tokens"
|
||||
),
|
||||
f"{service_tier_prefix}cache_read": (
|
||||
oai_token_usage.get("prompt_tokens_details") or {}
|
||||
).get("cached_tokens"),
|
||||
}
|
||||
output_token_details: dict = {
|
||||
"audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
|
||||
"audio_tokens"
|
||||
),
|
||||
"reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
|
||||
"reasoning_tokens"
|
||||
),
|
||||
f"{service_tier_prefix}reasoning": (
|
||||
oai_token_usage.get("completion_tokens_details") or {}
|
||||
).get("reasoning_tokens"),
|
||||
}
|
||||
if service_tier is not None:
|
||||
# Avoid counting cache and reasoning tokens towards the service tier token
|
||||
# counts, since service tier tokens are already priced differently
|
||||
input_token_details[service_tier] = input_tokens - input_token_details.get(
|
||||
f"{service_tier_prefix}cache_read", 0
|
||||
)
|
||||
output_token_details[service_tier] = output_tokens - output_token_details.get(
|
||||
f"{service_tier_prefix}reasoning", 0
|
||||
)
|
||||
return UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
@@ -3420,20 +3439,34 @@ def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
|
||||
)
|
||||
|
||||
|
||||
def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata:
|
||||
def _create_usage_metadata_responses(
|
||||
oai_token_usage: dict, service_tier: str | None = None
|
||||
) -> UsageMetadata:
|
||||
input_tokens = oai_token_usage.get("input_tokens", 0)
|
||||
output_tokens = oai_token_usage.get("output_tokens", 0)
|
||||
total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
|
||||
if service_tier not in {"priority", "flex"}:
|
||||
service_tier = None
|
||||
service_tier_prefix = f"{service_tier}_" if service_tier else ""
|
||||
output_token_details: dict = {
|
||||
"reasoning": (oai_token_usage.get("output_tokens_details") or {}).get(
|
||||
"reasoning_tokens"
|
||||
)
|
||||
f"{service_tier_prefix}reasoning": (
|
||||
oai_token_usage.get("output_tokens_details") or {}
|
||||
).get("reasoning_tokens")
|
||||
}
|
||||
input_token_details: dict = {
|
||||
"cache_read": (oai_token_usage.get("input_tokens_details") or {}).get(
|
||||
"cached_tokens"
|
||||
)
|
||||
f"{service_tier_prefix}cache_read": (
|
||||
oai_token_usage.get("input_tokens_details") or {}
|
||||
).get("cached_tokens")
|
||||
}
|
||||
if service_tier is not None:
|
||||
# Avoid counting cache and reasoning tokens towards the service tier token
|
||||
# counts, since service tier tokens are already priced differently
|
||||
output_token_details[service_tier] = output_tokens - output_token_details.get(
|
||||
f"{service_tier_prefix}reasoning", 0
|
||||
)
|
||||
input_token_details[service_tier] = input_tokens - input_token_details.get(
|
||||
f"{service_tier_prefix}cache_read", 0
|
||||
)
|
||||
return UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
@@ -3957,7 +3990,9 @@ def _construct_lc_result_from_responses_api(
|
||||
response_metadata["model_provider"] = "openai"
|
||||
response_metadata["model_name"] = response_metadata.get("model")
|
||||
if response.usage:
|
||||
usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
|
||||
usage_metadata = _create_usage_metadata_responses(
|
||||
response.usage.model_dump(), response.service_tier
|
||||
)
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
|
||||
@@ -194,13 +194,35 @@ def test_openai_invoke() -> None:
|
||||
result = llm.invoke("Hello", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
usage_metadata = result.usage_metadata # type: ignore[attr-defined]
|
||||
|
||||
# assert no response headers if include_response_headers is not set
|
||||
assert "headers" not in result.response_metadata
|
||||
assert usage_metadata is not None
|
||||
flex_input = usage_metadata.get("input_token_details", {}).get("flex")
|
||||
assert isinstance(flex_input, int)
|
||||
assert flex_input > 0
|
||||
assert flex_input == usage_metadata.get("input_tokens")
|
||||
flex_output = usage_metadata.get("output_token_details", {}).get("flex")
|
||||
assert isinstance(flex_output, int)
|
||||
assert flex_output > 0
|
||||
# GPT-5-nano/reasoning model specific. Remove if model used in test changes.
|
||||
flex_reasoning = usage_metadata.get("output_token_details", {}).get(
|
||||
"flex_reasoning"
|
||||
)
|
||||
assert isinstance(flex_reasoning, int)
|
||||
assert flex_reasoning > 0
|
||||
assert flex_reasoning + flex_output == usage_metadata.get("output_tokens")
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(model="gpt-4.1-mini")
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-5-nano",
|
||||
service_tier="flex", # Also test service_tier
|
||||
max_retries=3, # Add retries for 503 capacity errors
|
||||
)
|
||||
|
||||
full: BaseMessageChunk | None = None
|
||||
for chunk in llm.stream("I'm Pickle Rick"):
|
||||
@@ -236,6 +258,19 @@ def test_stream() -> None:
|
||||
assert aggregate.usage_metadata["input_tokens"] > 0
|
||||
assert aggregate.usage_metadata["output_tokens"] > 0
|
||||
assert aggregate.usage_metadata["total_tokens"] > 0
|
||||
assert aggregate.usage_metadata.get("input_token_details", {}).get("flex", 0) > 0 # type: ignore[operator]
|
||||
assert aggregate.usage_metadata.get("output_token_details", {}).get("flex", 0) > 0 # type: ignore[operator]
|
||||
assert (
|
||||
aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator]
|
||||
"flex_reasoning", 0
|
||||
)
|
||||
> 0
|
||||
)
|
||||
assert aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator]
|
||||
"flex_reasoning", 0
|
||||
) + aggregate.usage_metadata.get("output_token_details", {}).get(
|
||||
"flex", 0
|
||||
) == aggregate.usage_metadata.get("output_tokens")
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
@@ -308,6 +343,28 @@ async def test_astream() -> None:
|
||||
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False, True])
|
||||
def test_flex_usage_responses(streaming: bool) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model="gpt-5-nano",
|
||||
service_tier="flex",
|
||||
max_retries=3,
|
||||
use_responses_api=True,
|
||||
streaming=streaming,
|
||||
)
|
||||
result = llm.invoke("Hello")
|
||||
assert result.usage_metadata
|
||||
flex_input = result.usage_metadata.get("input_token_details", {}).get("flex")
|
||||
flex_output = result.usage_metadata.get("output_token_details", {}).get("flex")
|
||||
flex_reasoning = result.usage_metadata.get("output_token_details", {}).get(
|
||||
"flex_reasoning"
|
||||
)
|
||||
assert isinstance(flex_input, int)
|
||||
assert isinstance(flex_output, int)
|
||||
assert isinstance(flex_reasoning, int)
|
||||
assert flex_output + flex_reasoning == result.usage_metadata.get("output_tokens")
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
Reference in New Issue
Block a user