[fx] added module patch for pooling layers (#1197)

This commit is contained in:
Frank Lee
2022-07-04 15:21:26 +08:00
committed by GitHub
parent 23442a5bc1
commit abf6a262dc
4 changed files with 91 additions and 1 deletions

View File

@@ -225,3 +225,33 @@ def test_conv3d():
patch_fn=patched_module.torch_nn_conv3d,
expect_exception=False,
output_shape=materialized_output.shape)
def test_maxpool3d():
pooler = torch.nn.MaxPool3d(kernel_size=3)
# test max pool 3d
data = torch.rand(2, 3, 4, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=True,
output_shape=None)