diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 8e5d5c7b948..00a99c4e6a5 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -575,6 +575,21 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] response.citations ) + # Unlike OpenAI, xAI reports reasoning tokens < completion tokens. So we assume + # they are not counted in output tokens, and we add them here. + if ( + (not self._use_responses_api({})) + and (usage_metadata := rtn.generations[0].message.usage_metadata) # type: ignore[attr-defined] + and ( + reasoning_tokens := usage_metadata.get("output_token_details", {}).get( + "reasoning" + ) + ) + ): + rtn.generations[0].message.usage_metadata["output_tokens"] += ( # type: ignore[attr-defined] + reasoning_tokens + ) + return rtn def _convert_chunk_to_generation_chunk( @@ -609,6 +624,19 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] ): generation_chunk.message.additional_kwargs["citations"] = citations + # Unlike OpenAI, xAI reports reasoning tokens < completion tokens. So we assume + # they are not counted in output tokens, and we add them here. + if ( + generation_chunk + and (not self._use_responses_api({})) + and (usage_metadata := generation_chunk.message.usage_metadata) # type: ignore[attr-defined] + and ( + reasoning_tokens := usage_metadata.get("output_token_details", {}).get( + "reasoning" + ) + ) + ): + generation_chunk.message.usage_metadata["output_tokens"] += reasoning_tokens # type: ignore[attr-defined] return generation_chunk def with_structured_output( diff --git a/libs/partners/xai/tests/integration_tests/test_chat_models.py b/libs/partners/xai/tests/integration_tests/test_chat_models.py index 72d17fc9d19..6a192b5dc99 100644 --- a/libs/partners/xai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/xai/tests/integration_tests/test_chat_models.py @@ -37,6 +37,15 @@ def test_reasoning(output_version: Literal["", "v1"]) -> None: assert response.content assert response.additional_kwargs["reasoning_content"] + ## Check output tokens + usage_metadata = response.usage_metadata + assert usage_metadata + reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning") + total_tokens = usage_metadata.get("output_tokens") + assert total_tokens + assert reasoning_tokens + assert total_tokens > reasoning_tokens + # Test streaming full: BaseMessageChunk | None = None for chunk in chat_model.stream(input_message): @@ -44,6 +53,15 @@ def test_reasoning(output_version: Literal["", "v1"]) -> None: assert isinstance(full, AIMessageChunk) assert full.additional_kwargs["reasoning_content"] + ## Check output tokens + usage_metadata = full.usage_metadata + assert usage_metadata + reasoning_tokens = usage_metadata.get("output_token_details", {}).get("reasoning") + total_tokens = usage_metadata.get("output_tokens") + assert total_tokens + assert reasoning_tokens + assert total_tokens > reasoning_tokens + # Check that we can access reasoning content blocks assert response.content_blocks reasoning_content = (