[fx] fixed adapative pooling size concatenation error (#1489)

This commit is contained in:
Frank Lee
2022-08-25 09:05:07 +08:00
committed by GitHub
parent cde7b8a5b8
commit 3da68d6b1b
2 changed files with 98 additions and 17 deletions

View File

@@ -407,3 +407,76 @@ def test_pool3d():
# test max pool 3d
data = torch.rand(2, 3, 4)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
# adapative pooling is different from other pooling, so test it individually
def test_adaptive_pooling_1d():
pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
patch_func = patched_module.torch_nn_adapative_pooling_1d
data = torch.rand(3, 4)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)
data = torch.rand(2, 3, 4)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)
data = torch.rand(2, 3, 4, 5)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
def test_adaptive_pooling_2d():
pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
patch_func = patched_module.torch_nn_adapative_pooling_2d
data = torch.rand(3, 4)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
data = torch.rand(2, 3, 4)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)
data = torch.rand(2, 3, 4, 5)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)
def test_adaptive_pooling_3d():
pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
patch_func = patched_module.torch_nn_adapative_pooling_3d
data = torch.rand(3, 4, 5)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
data = torch.rand(2, 3, 4, 5)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)
data = torch.rand(2, 3, 4, 5, 6)
output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=output.shape)