[fx] added module patch for pooling layers (#1197)

This commit is contained in:
Frank Lee
2022-07-04 15:21:26 +08:00
committed by GitHub
parent 23442a5bc1
commit abf6a262dc
4 changed files with 91 additions and 1 deletions

View File

@@ -0,0 +1,31 @@
import torch
import torch.nn
def test_maxpool():
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)),
maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4)))
for name, info in layer_to_test.items():
data = torch.rand(*info['shape'])
meta_data = data.to('meta')
layer = info['layer'](kernel_size=3)
out = layer(data)
meta_out = layer(meta_data)
assert meta_out.is_meta
assert out.shape == meta_out.shape
def test_avgpool():
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)),
maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)),
maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4)))
for name, info in layer_to_test.items():
data = torch.rand(*info['shape'])
meta_data = data.to('meta')
layer = info['layer'](kernel_size=3)
out = layer(data)
meta_out = layer(meta_data)
assert meta_out.is_meta
assert out.shape == meta_out.shape

View File

@@ -225,3 +225,33 @@ def test_conv3d():
patch_fn=patched_module.torch_nn_conv3d,
expect_exception=False,
output_shape=materialized_output.shape)
def test_maxpool3d():
pooler = torch.nn.MaxPool3d(kernel_size=3)
# test max pool 3d
data = torch.rand(2, 3, 4, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=True,
output_shape=None)