[fx] add torchaudio test (#1369)

* [fx]add torchaudio test

* [fx]add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test and test patches

* Delete ~

* [fx] add patches and patches test

* [fx] add patches and patches test

* [fx] fix patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] merge upstream

* [fx] fix import errors
This commit is contained in:
Super Daniel
2022-07-27 11:03:14 +08:00
committed by GitHub
parent fb6f085907
commit be229217ce
18 changed files with 609 additions and 16 deletions

View File

@@ -4,7 +4,12 @@ from colossalai.fx.tracer.meta_patch import patched_module
def _run(data, module, patch_fn):
try:
output = patch_fn(module, data)
if isinstance(data, dict):
output = patch_fn(module, **data)
if isinstance(data, tuple) or isinstance(data, list):
output = patch_fn(module, *data)
else:
output = patch_fn(module, data)
return output
except Exception as e:
return e
@@ -17,8 +22,13 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape)
assert isinstance(output, AssertionError)
else:
assert not isinstance(output, Exception)
assert output.is_meta
assert output.shape == output_shape
if isinstance(output, tuple):
for item, shape in zip(output, output_shape):
assert item.is_meta
assert item.shape == shape
else:
assert output.is_meta
assert output.shape == output_shape
def test_linear():
@@ -27,11 +37,27 @@ def test_linear():
module = torch.nn.Linear(4, 2)
_assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2]))
# Test if the linear patch can catch exception when dimension does not match
# test if the linear patch can catch exception when dimension does not match
data = torch.rand(2, 2, device='meta')
_assert_output_shape(data, module, patched_module.torch_nn_linear, True, None)
def test_rnn():
# test rnn patch can produce the meta output with correct shape
data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
module = torch.nn.RNN(10, 20, 2)
output, hn = module(*data)
meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta'))
_assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape))
# test if the rnn patch can catch exception when dimension does not match
data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
module = torch.nn.RNN(10, 20, 2)
output, hn = module(*data)
meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta'))
_assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None)
def test_embedding():
data = torch.rand(2, 4, device='meta')
@@ -146,7 +172,7 @@ def test_conv1d():
def test_conv2d():
# test conv 1d
# test conv 2d
data = torch.rand(2, 3, 4, 4)
conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2)
materialized_output = conv2d(data)
@@ -187,7 +213,7 @@ def test_conv2d():
def test_conv3d():
# test conv 1d
# test conv 3d
data = torch.rand(2, 3, 4, 4, 4)
conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2)
materialized_output = conv3d(data)
@@ -227,6 +253,75 @@ def test_conv3d():
output_shape=materialized_output.shape)
def test_conv_transpose1d():
# test conv transpose1d
data = torch.rand(2, 3, 4)
convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2)
materialized_output = convtrans1d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans1d,
patch_fn=patched_module.torch_nn_convtranspose1d,
expect_exception=False,
output_shape=materialized_output.shape)
convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
materialized_output = convtrans1d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans1d,
patch_fn=patched_module.torch_nn_convtranspose1d,
expect_exception=False,
output_shape=materialized_output.shape)
def test_conv_transpose2d():
# test conv transpose2d
data = torch.rand(2, 3, 4, 4)
convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2)
materialized_output = convtrans2d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans2d,
patch_fn=patched_module.torch_nn_convtranspose2d,
expect_exception=False,
output_shape=materialized_output.shape)
convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
materialized_output = convtrans2d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans2d,
patch_fn=patched_module.torch_nn_convtranspose2d,
expect_exception=False,
output_shape=materialized_output.shape)
def test_conv_transpose3d():
# test conv transpose2d
data = torch.rand(2, 3, 4, 4, 4)
convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2)
materialized_output = convtrans3d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans3d,
patch_fn=patched_module.torch_nn_convtranspose3d,
expect_exception=False,
output_shape=materialized_output.shape)
convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1)
materialized_output = convtrans3d(data)
meta_data = data.to('meta')
_assert_output_shape(data=meta_data,
module=convtrans3d,
patch_fn=patched_module.torch_nn_convtranspose3d,
expect_exception=False,
output_shape=materialized_output.shape)
def test_pool1d():
combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
[torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]