From 3da68d6b1b2a3f694a795bc44d61aab08e266e47 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 25 Aug 2022 09:05:07 +0800 Subject: [PATCH] [fx] fixed adaptive pooling size concatenation error (#1489) --- .../meta_patch/patched_module/pooling.py | 42 ++++++----- .../test_tracer/test_patched_module.py | 73 +++++++++++++++++++ 2 files changed, 98 insertions(+), 17 deletions(-) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py index a336120f5..f740f8511 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -22,7 +22,7 @@ def torch_nn_avgpool1d(self, input): l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) - result_shape = input.shape[:-1] + (l_out,) + result_shape = tuple(input.shape[:-1]) + (l_out,) return torch.empty(result_shape, device='meta') @@ -46,7 +46,7 @@ def torch_nn_avgpool2d(self, input): h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) - result_shape = input.shape[:-2] + ( + result_shape = tuple(input.shape[:-2]) + ( h_out, w_out, ) @@ -74,7 +74,7 @@ def torch_nn_avgpool3d(self, input): h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1) w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1) - result_shape = input.shape[:-3] + ( + result_shape = tuple(input.shape[:-3]) + ( d_out, h_out, w_out, ) @@ -102,7 +102,7 @@ def torch_nn_maxpool1d(self, input): l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) - result_shape = input.shape[:-1] + (l_out,) + result_shape = tuple(input.shape[:-1]) + (l_out,) return torch.empty(result_shape, device='meta') @@ -127,7 +127,7 @@ def torch_nn_maxpool2d(self, input): h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * 
(kernel_size[0] - 1) - 1) / stride[0] + 1) w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) - result_shape = input.shape[:-2] + ( + result_shape = tuple(input.shape[:-2]) + ( h_out, w_out, ) @@ -156,7 +156,7 @@ def torch_nn_maxpool3d(self, input): h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) - result_shape = input.shape[:-3] + ( + result_shape = tuple(input.shape[:-3]) + ( d_out, h_out, w_out, @@ -167,26 +167,34 @@ def torch_nn_maxpool3d(self, input): @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) @meta_patched_module.register(torch.nn.AdaptiveMaxPool1d) def torch_nn_adapative_pooling_1d(self, input): - result_shape = input.shape[:-1] + (self.output_size,) + assert input.dim() in [2, 3] + if isinstance(self.output_size, int): + output_size = (self.output_size,) + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-1]) + output_size return torch.empty(result_shape, device='meta') @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) @meta_patched_module.register(torch.nn.AdaptiveMaxPool2d) def torch_nn_adapative_pooling_2d(self, input): - result_shape = input.shape[:-2] + ( - self.output_size, - self.output_size, - ) + assert input.dim() in [3, 4] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 2 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-2]) + output_size return torch.empty(result_shape, device='meta') @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) @meta_patched_module.register(torch.nn.AdaptiveMaxPool3d) def torch_nn_adapative_pooling_3d(self, input): - result_shape = input.shape[:-3] + ( - self.output_size, - self.output_size, - self.output_size, - ) - return torch.empty(result_shape, device='meta') \ No newline at end of file + assert 
input.dim() in [4, 5] + if isinstance(self.output_size, int): + output_size = (self.output_size,) * 3 + else: + output_size = self.output_size + result_shape = tuple(input.shape[:-3]) + output_size + return torch.empty(result_shape, device='meta') diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 82389ae41..94a93e16f 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -407,3 +407,76 @@ def test_pool3d(): # test max pool 3d data = torch.rand(2, 3, 4) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +# adapative pooling is different from other pooling, so test it individually +def test_adaptive_pooling_1d(): + pooler = torch.nn.AdaptiveAvgPool1d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_1d + + data = torch.rand(3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + +def test_adaptive_pooling_2d(): + pooler = torch.nn.AdaptiveAvgPool2d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_2d + + data = torch.rand(3, 4) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + 
_assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + +def test_adaptive_pooling_3d(): + pooler = torch.nn.AdaptiveAvgPool3d(output_size=3) + patch_func = patched_module.torch_nn_adapative_pooling_3d + + data = torch.rand(3, 4, 5) + _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) + + data = torch.rand(2, 3, 4, 5) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape) + + data = torch.rand(2, 3, 4, 5, 6) + output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=output.shape)