mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[fx] fixed adapative pooling size concatenation error (#1489)
This commit is contained in:
@@ -407,3 +407,76 @@ def test_pool3d():
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
# adapative pooling is different from other pooling, so test it individually
|
||||
def test_adaptive_pooling_1d():
|
||||
pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
|
||||
patch_func = patched_module.torch_nn_adapative_pooling_1d
|
||||
|
||||
data = torch.rand(3, 4)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
||||
data = torch.rand(2, 3, 4)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
||||
data = torch.rand(2, 3, 4, 5)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
def test_adaptive_pooling_2d():
|
||||
pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
|
||||
patch_func = patched_module.torch_nn_adapative_pooling_2d
|
||||
|
||||
data = torch.rand(3, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
data = torch.rand(2, 3, 4)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
||||
data = torch.rand(2, 3, 4, 5)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
||||
|
||||
def test_adaptive_pooling_3d():
|
||||
pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
|
||||
patch_func = patched_module.torch_nn_adapative_pooling_3d
|
||||
|
||||
data = torch.rand(3, 4, 5)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
data = torch.rand(2, 3, 4, 5)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
||||
data = torch.rand(2, 3, 4, 5, 6)
|
||||
output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=output.shape)
|
||||
|
Reference in New Issue
Block a user