From abf6a262dc5a210c1f7d62617154eeebdc38253e Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 4 Jul 2022 15:21:26 +0800 Subject: [PATCH] [fx] added module patch for pooling layers (#1197) --- colossalai/fx/tracer/meta_patch/__init__.py | 1 - .../fx/tracer/meta_patch/patched_module.py | 30 ++++++++++++++++++ .../test_tracer/test_non_patched_module.py | 31 +++++++++++++++++++ .../test_tracer/test_patched_module.py | 30 ++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 tests/test_fx/test_tracer/test_non_patched_module.py diff --git a/colossalai/fx/tracer/meta_patch/__init__.py b/colossalai/fx/tracer/meta_patch/__init__.py index 5f0c4745b..28b54b9bb 100644 --- a/colossalai/fx/tracer/meta_patch/__init__.py +++ b/colossalai/fx/tracer/meta_patch/__init__.py @@ -1,4 +1,3 @@ -from sys import meta_path from .registry import * from .patched_function import * from .patched_module import * diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py index 2eff50882..e3ece40df 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module.py +++ b/colossalai/fx/tracer/meta_patch/patched_module.py @@ -86,3 +86,33 @@ def torch_nn_conv3d(self, input): w_out, ) return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.MaxPool3d) +def torch_nn_maxpool3d(self, input): + num_dim = input.dim() + assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + + d_in, h_in, w_in = input.shape[-3:] + + def _convert_int_to_list(item): + if isinstance(item, int): + return [item] * 3 + else: + return item + + padding = _convert_int_to_list(self.padding) + dilation = _convert_int_to_list(self.dilation) + kernel_size = _convert_int_to_list(self.kernel_size) + stride = _convert_int_to_list(self.stride) + + d_out = math.floor((d_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) + h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) + w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1) + + result_shape = input.shape[:-3] + ( + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/tests/test_fx/test_tracer/test_non_patched_module.py b/tests/test_fx/test_tracer/test_non_patched_module.py new file mode 100644 index 000000000..9abc964b0 --- /dev/null +++ b/tests/test_fx/test_tracer/test_non_patched_module.py @@ -0,0 +1,31 @@ +import torch +import torch.nn + + +def test_maxpool(): + layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)), + maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4))) + + for name, info in layer_to_test.items(): + data = torch.rand(*info['shape']) + meta_data = data.to('meta') + layer = info['layer'](kernel_size=3) + out = layer(data) + meta_out = layer(meta_data) + assert meta_out.is_meta + assert out.shape == meta_out.shape + + +def test_avgpool(): + layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)), + maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)), + maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4))) + + for name, info in layer_to_test.items(): + data = torch.rand(*info['shape']) + meta_data = data.to('meta') + layer = info['layer'](kernel_size=3) + out = layer(data) + meta_out = layer(meta_data) + assert meta_out.is_meta + assert out.shape == meta_out.shape diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 0cb38f436..d96bc04ac 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -225,3 +225,33 @@ def test_conv3d(): patch_fn=patched_module.torch_nn_conv3d, expect_exception=False, output_shape=materialized_output.shape) + + +def test_maxpool3d(): + pooler = torch.nn.MaxPool3d(kernel_size=3) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patched_module.torch_nn_maxpool3d, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4, 4) + materialized_output = pooler(data) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patched_module.torch_nn_maxpool3d, + expect_exception=False, + output_shape=materialized_output.shape) + + # test max pool 3d + data = torch.rand(2, 3, 4) + _assert_output_shape(data=data, + module=pooler, + patch_fn=patched_module.torch_nn_maxpool3d, + expect_exception=True, + output_shape=None)