diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 8e8a3569c5b..9e225788403 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -917,7 +917,9 @@ class BaseChatOpenAI(BaseChatModel): ) usage_metadata: UsageMetadata | None = ( - _create_usage_metadata(token_usage) if token_usage else None + _create_usage_metadata(token_usage, chunk.get("service_tier")) + if token_usage + else None ) if len(choices) == 0: # logprobs is implicitly None @@ -1289,11 +1291,14 @@ class BaseChatOpenAI(BaseChatModel): raise TypeError(msg) token_usage = response_dict.get("usage") + service_tier = response_dict.get("service_tier") for res in choices: message = _convert_dict_to_message(res["message"]) if token_usage and isinstance(message, AIMessage): - message.usage_metadata = _create_usage_metadata(token_usage) + message.usage_metadata = _create_usage_metadata( + token_usage, service_tier + ) generation_info = generation_info or {} generation_info["finish_reason"] = ( res.get("finish_reason") @@ -1312,8 +1317,8 @@ class BaseChatOpenAI(BaseChatModel): } if "id" in response_dict: llm_output["id"] = response_dict["id"] - if "service_tier" in response_dict: - llm_output["service_tier"] = response_dict["service_tier"] + if service_tier: + llm_output["service_tier"] = service_tier if isinstance(response, openai.BaseModel) and getattr( response, "choices", None @@ -3387,26 +3392,40 @@ class OpenAIRefusalError(Exception): """ -def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata: +def _create_usage_metadata( + oai_token_usage: dict, service_tier: str | None = None +) -> UsageMetadata: input_tokens = oai_token_usage.get("prompt_tokens") or 0 output_tokens = oai_token_usage.get("completion_tokens") or 0 total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens + if service_tier not in {"priority", "flex"}: + service_tier = None + service_tier_prefix = f"{service_tier}_" if service_tier else "" input_token_details: dict = { "audio": (oai_token_usage.get("prompt_tokens_details") or {}).get( "audio_tokens" ), - "cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get( - "cached_tokens" - ), + f"{service_tier_prefix}cache_read": ( + oai_token_usage.get("prompt_tokens_details") or {} + ).get("cached_tokens"), } output_token_details: dict = { "audio": (oai_token_usage.get("completion_tokens_details") or {}).get( "audio_tokens" ), - "reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get( - "reasoning_tokens" - ), + f"{service_tier_prefix}reasoning": ( + oai_token_usage.get("completion_tokens_details") or {} + ).get("reasoning_tokens"), } + if service_tier is not None: + # Avoid counting cache and reasoning tokens towards the service tier token + # counts, since service tier tokens are already priced differently + input_token_details[service_tier] = input_tokens - input_token_details.get( + f"{service_tier_prefix}cache_read", 0 + ) + output_token_details[service_tier] = output_tokens - output_token_details.get( + f"{service_tier_prefix}reasoning", 0 + ) return UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, @@ -3420,20 +3439,34 @@ def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata: ) -def _create_usage_metadata_responses(oai_token_usage: dict) -> UsageMetadata: +def _create_usage_metadata_responses( + oai_token_usage: dict, service_tier: str | None = None +) -> UsageMetadata: input_tokens = oai_token_usage.get("input_tokens", 0) output_tokens = oai_token_usage.get("output_tokens", 0) total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens) + if service_tier not in {"priority", "flex"}: + service_tier = None + service_tier_prefix = f"{service_tier}_" if service_tier else "" output_token_details: dict = { - "reasoning": (oai_token_usage.get("output_tokens_details") or {}).get( - "reasoning_tokens" - ) + f"{service_tier_prefix}reasoning": ( + oai_token_usage.get("output_tokens_details") or {} + ).get("reasoning_tokens") } input_token_details: dict = { - "cache_read": (oai_token_usage.get("input_tokens_details") or {}).get( - "cached_tokens" - ) + f"{service_tier_prefix}cache_read": ( + oai_token_usage.get("input_tokens_details") or {} + ).get("cached_tokens") } + if service_tier is not None: + # Avoid counting cache and reasoning tokens towards the service tier token + # counts, since service tier tokens are already priced differently + output_token_details[service_tier] = output_tokens - output_token_details.get( + f"{service_tier_prefix}reasoning", 0 + ) + input_token_details[service_tier] = input_tokens - input_token_details.get( + f"{service_tier_prefix}cache_read", 0 + ) return UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, @@ -3957,7 +3990,9 @@ def _construct_lc_result_from_responses_api( response_metadata["model_provider"] = "openai" response_metadata["model_name"] = response_metadata.get("model") if response.usage: - usage_metadata = _create_usage_metadata_responses(response.usage.model_dump()) + usage_metadata = _create_usage_metadata_responses( + response.usage.model_dump(), response.service_tier + ) else: usage_metadata = None diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index edfeae36e35..945e83124d3 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -194,13 +194,35 @@ def test_openai_invoke() -> None: result = llm.invoke("Hello", config={"tags": ["foo"]}) assert isinstance(result.content, str) + usage_metadata = result.usage_metadata # type: ignore[attr-defined] + # assert no response headers if include_response_headers is not set assert "headers" not in result.response_metadata + assert usage_metadata is not None + flex_input = usage_metadata.get("input_token_details", {}).get("flex") + assert isinstance(flex_input, int) + assert flex_input > 0 + assert flex_input == usage_metadata.get("input_tokens") + flex_output = usage_metadata.get("output_token_details", {}).get("flex") + assert isinstance(flex_output, int) + assert flex_output > 0 + # GPT-5-nano/reasoning model specific. Remove if model used in test changes. + flex_reasoning = usage_metadata.get("output_token_details", {}).get( + "flex_reasoning" + ) + assert isinstance(flex_reasoning, int) + assert flex_reasoning > 0 + assert flex_reasoning + flex_output == usage_metadata.get("output_tokens") +@pytest.mark.flaky(retries=3, delay=1) def test_stream() -> None: """Test streaming tokens from OpenAI.""" - llm = ChatOpenAI(model="gpt-4.1-mini") + llm = ChatOpenAI( + model="gpt-5-nano", + service_tier="flex", # Also test service_tier + max_retries=3, # Add retries for 503 capacity errors + ) full: BaseMessageChunk | None = None for chunk in llm.stream("I'm Pickle Rick"): @@ -236,6 +258,19 @@ def test_stream() -> None: assert aggregate.usage_metadata["input_tokens"] > 0 assert aggregate.usage_metadata["output_tokens"] > 0 assert aggregate.usage_metadata["total_tokens"] > 0 + assert aggregate.usage_metadata.get("input_token_details", {}).get("flex", 0) > 0 # type: ignore[operator] + assert aggregate.usage_metadata.get("output_token_details", {}).get("flex", 0) > 0 # type: ignore[operator] + assert ( + aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator] + "flex_reasoning", 0 + ) + > 0 + ) + assert aggregate.usage_metadata.get("output_token_details", {}).get( # type: ignore[operator] + "flex_reasoning", 0 + ) + aggregate.usage_metadata.get("output_token_details", {}).get( + "flex", 0 + ) == aggregate.usage_metadata.get("output_tokens") async def test_astream() -> None: @@ -308,6 +343,28 @@ async def test_astream() -> None: await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False) +@pytest.mark.parametrize("streaming", [False, True]) +def test_flex_usage_responses(streaming: bool) -> None: + llm = ChatOpenAI( + model="gpt-5-nano", + service_tier="flex", + max_retries=3, + use_responses_api=True, + streaming=streaming, + ) + result = llm.invoke("Hello") + assert result.usage_metadata + flex_input = result.usage_metadata.get("input_token_details", {}).get("flex") + flex_output = result.usage_metadata.get("output_token_details", {}).get("flex") + flex_reasoning = result.usage_metadata.get("output_token_details", {}).get( + "flex_reasoning" + ) + assert isinstance(flex_input, int) + assert isinstance(flex_output, int) + assert isinstance(flex_reasoning, int) + assert flex_output + flex_reasoning == result.usage_metadata.get("output_tokens") + + async def test_abatch_tags() -> None: """Test batch tokens from ChatOpenAI.""" llm = ChatOpenAI()