From be229217ce963ace870795ee281c98e438b5130b Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Wed, 27 Jul 2022 11:03:14 +0800 Subject: [PATCH] [fx] add torchaudio test (#1369) * [fx]add torchaudio test * [fx]add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test * [fx] add torchaudio test and test patches * Delete ~ * [fx] add patches and patches test * [fx] add patches and patches test * [fx] fix patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] fix rnn patches * [fx] merge upstream * [fx] fix import errors --- .../meta_patch/patched_function/torch_ops.py | 21 +++ .../meta_patch/patched_module/__init__.py | 3 +- .../patched_module/activation_function.py | 1 + .../meta_patch/patched_module/convolution.py | 57 +++++++ .../meta_patch/patched_module/linear.py | 2 +- .../tracer/meta_patch/patched_module/rnn.py | 14 ++ .../utils/checkpoint/module_checkpoint.py | 2 +- colossalai/utils/checkpoint/utils.py | 7 +- requirements/requirements-test.txt | 1 + setup.py | 4 +- .../test_tracer/test_patched_module.py | 107 ++++++++++++- tests/test_fx/test_tracer/test_patched_op.py | 63 ++++++++ .../test_timm_model/test_timm_model.py | 2 +- .../test_torchaudio_general.py | 145 ++++++++++++++++++ .../test_torchaudio_tacotron.py | 57 +++++++ .../test_torchaudio_transformer.py | 61 ++++++++ .../test_torchaudio_wave2vec.py | 50 ++++++ .../test_torchaudio_model/torchaudio_utils.py | 28 ++++ 18 files changed, 609 insertions(+), 16 deletions(-) create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/rnn.py create mode 100644 tests/test_fx/test_tracer/test_patched_op.py create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py create mode 100644 tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index 4c5c7c2e3..2ee5cb112 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -108,6 +108,27 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None): return torch.empty(final_shape, device="meta") +@meta_patched_function.register(torch.repeat_interleave) +def torch_repeat_interleave(input, repeats, dim=None, output_size=None): + assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ + "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + + shape = list(input.shape) if dim is not None else [input.numel()] + dim = dim if dim is not None else 0 + dim = input.dim() + dim if dim < 0 else dim + + if isinstance(repeats, int): + shape[dim] = shape[dim] * repeats + elif isinstance(repeats, torch.Tensor): + shape[dim] = repeats.sum() + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat_interleave) +def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None): + return torch_repeat_interleave(self, repeats, dim, output_size) + + @meta_patched_function.register(torch.roll) def torch_roll(input, shifts, dims=None): return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py index bd550487c..e28e52585 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -3,4 +3,5 @@ from .convolution import * from .embedding import * from .linear import * from .normalization import * -from .pooling import * \ No newline at end of file +from .pooling import * +from .rnn import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py index ed2f4bcaf..ed572e3b7 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -7,5 +7,6 @@ from ..registry import meta_patched_module @meta_patched_module.register(torch.nn.GELU) @meta_patched_module.register(torch.nn.Tanh) @meta_patched_module.register(torch.nn.ReLU6) +@meta_patched_module.register(torch.nn.PReLU) def torch_nn_non_linear_act(self, input): return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py index b600f4df2..327450cb2 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -55,3 +55,60 @@ def torch_nn_conv3d(self, input): w_out, ) return torch.empty(result_shape, device='meta') + +@meta_patched_module.register(torch.nn.ConvTranspose1d) +def torch_nn_convtranspose1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + +@meta_patched_module.register(torch.nn.ConvTranspose2d) +def torch_nn_convtranspose2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + 1) + w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + +@meta_patched_module.register(torch.nn.ConvTranspose3d) +def torch_nn_convtranspose3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py index 1f22ffd60..0275f134d 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -6,4 +6,4 @@ from ..registry import meta_patched_module def torch_nn_linear(self, input): last_dim = input.shape[-1] assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' - return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") \ No newline at end of file + return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py new file mode 100644 index 000000000..15a0be417 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -0,0 +1,14 @@ +import torch +from ..registry import meta_patched_module +from typing import Optional + + +@meta_patched_module.register(torch.nn.GRU) +@meta_patched_module.register(torch.nn.RNN) +def torch_nn_rnn(self, input, hx): + assert input.shape[ + -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' + assert hx.shape[ + -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + d = 2 if self.bidirectional else 1 + return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py index cf9b11cc6..43c518394 100644 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -27,7 +27,7 @@ def save_checkpoint(dire: str, # save the dist context about the tensors in a new dict, while still maintain the original dict. for k, v in model_state.items(): if isinstance(v, ColoTensor): - gather_tensor(v) # gather shared tensors to rank0 + gather_tensor(v) # gather shared tensors to rank0 # don't recover tensors in rank0, since the dict is only a copy of model if rank == 0: diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py index a9e0e7edd..cd6f85175 100644 --- a/colossalai/utils/checkpoint/utils.py +++ b/colossalai/utils/checkpoint/utils.py @@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None: dist.barrier() if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signitrue + setattr(colo_tensor, 'save_ready', True) # set saving signitrue def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: @@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: if dist.get_rank() == 0: colo_tensor.set_dist_spec(dist_spec) else: - rep_tensor = ColoTensor(entire_data, ColoTensorSpec( - pg=colo_tensor.get_process_group(), - compute_attr=colo_tensor.compute_spec)) + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) rep_tensor.set_dist_spec(dist_spec) with torch.no_grad(): colo_tensor.data.copy_(rep_tensor.data) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 8e4e5268d..e69bcc244 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -3,4 +3,5 @@ torchvision transformers timm titans +torchaudio torchrec diff --git a/setup.py b/setup.py index 10906b61c..b58a4d989 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ def get_version(): version += f'+torch{torch_version}cu{cuda_version}' return version - + if build_cuda_ext: try: import torch @@ -115,7 +115,7 @@ if build_cuda_ext: except ImportError: print('torch is not found. CUDA extension will not be installed') build_cuda_ext = False - + if build_cuda_ext: build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 9b4f7c516..82389ae41 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -4,7 +4,12 @@ from colossalai.fx.tracer.meta_patch import patched_module def _run(data, module, patch_fn): try: - output = patch_fn(module, data) + if isinstance(data, dict): + output = patch_fn(module, **data) + if isinstance(data, tuple) or isinstance(data, list): + output = patch_fn(module, *data) + else: + output = patch_fn(module, data) return output except Exception as e: return e @@ -17,8 +22,13 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) assert isinstance(output, AssertionError) else: assert not isinstance(output, Exception) - assert output.is_meta - assert output.shape == output_shape + if isinstance(output, tuple): + for item, shape in zip(output, output_shape): + assert item.is_meta + assert item.shape == shape + else: + assert output.is_meta + assert output.shape == output_shape def test_linear(): @@ -27,11 +37,27 @@ def test_linear(): module = torch.nn.Linear(4, 2) _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2])) - # Test if the linear patch can catch exception when dimension does not match + # test if the linear patch can catch exception when dimension does not match data = torch.rand(2, 2, device='meta') _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) +def test_rnn(): + # test rnn patch can produce the meta output with correct shape + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) + + # test if the rnn patch can catch exception when dimension does not match + data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) + module = torch.nn.RNN(10, 20, 2) + output, hn = module(*data) + meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) + + def test_embedding(): data = torch.rand(2, 4, device='meta') @@ -146,7 +172,7 @@ def test_conv1d(): def test_conv2d(): - # test conv 1d + # test conv 2d data = torch.rand(2, 3, 4, 4) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv2d(data) @@ -187,7 +213,7 @@ def test_conv2d(): def test_conv3d(): - # test conv 1d + # test conv 3d data = torch.rand(2, 3, 4, 4, 4) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv3d(data) @@ -227,6 +253,75 @@ def test_conv3d(): output_shape=materialized_output.shape) +def test_conv_transpose1d(): + # test conv transpose1d + data = torch.rand(2, 3, 4) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans1d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv_transpose2d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans2d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape) + + +def test_conv_transpose3d(): + # test conv transpose2d + data = torch.rand(2, 3, 4, 4, 4) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) + materialized_output = convtrans3d(data) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape) + + def test_pool1d(): combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py new file mode 100644 index 000000000..05c29b824 --- /dev/null +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -0,0 +1,63 @@ +import torch +from colossalai.fx.tracer.meta_patch import patched_function +from functools import partial + + +def _run(data, patch_fn): + try: + output = patch_fn(data) + return output + except Exception as e: + return e + + +def _assert_output_shape(data, patch_fn, expect_exception, output_shape): + output = _run(data, patch_fn) + + if expect_exception: + assert isinstance(output, AssertionError) + else: + assert not isinstance(output, Exception) + assert output.is_meta + assert output.shape == output_shape + + +def test_repeat_interleave(): + patch_fn = patched_function.torch_repeat_interleave + + # examples from https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html + data = torch.tensor([1, 2, 3]) + materialized_output = torch.repeat_interleave(data, repeats=2) + repeat_interleave = partial(patch_fn, repeats=2) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) + repeat_interleave = partial(patch_fn, repeats=3, dim=1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) + repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=False, + output_shape=materialized_output.shape) + + data = torch.tensor([[1, 2], [3, 4]]) + materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) + repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) + meta_data = data.to('meta') + _assert_output_shape(data=meta_data, + patch_fn=repeat_interleave, + expect_exception=True, + output_shape=materialized_output.shape) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 2ee498b9e..1ce679d4c 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -22,7 +22,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None): with torch.no_grad(): fx_out = gm(data) non_fx_out = model(data) - + # compare output if isinstance(fx_out, tuple): # some models produce tuple as output diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py new file mode 100644 index 000000000..b2fa8c6c0 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py @@ -0,0 +1,145 @@ +import torch +from torchaudio_utils import trace_and_compare +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN +from torchaudio.models.wavernn import MelResNet, UpsampleNetwork +import pytest + + +def test_wave2letter_waveform(): + batch_size = 2 + num_features = 1 + num_classes = 40 + input_length = 320 + + model = Wav2Letter(num_classes=num_classes, num_features=num_features) + + def data_gen(): + x = torch.rand(batch_size, num_features, input_length) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_wave2letter_mfcc(): + batch_size = 2 + num_features = 13 + num_classes = 40 + input_length = 2 + + model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features) + + def data_gen(): + x = torch.rand(batch_size, num_features, input_length) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_melresnet_waveform(): + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 128 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + + def data_gen(): + x = torch.rand(n_batch, n_freq, n_time) + return dict(specgram=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_upsample_network_waveform(): + upsample_scales = [5, 5, 8] + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 64 + n_res_block = 10 + n_hidden = 32 + kernel_size = 5 + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size) + + def data_gen(): + x = torch.rand(n_batch, n_freq, n_time) + return dict(specgram=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +def test_wavernn_waveform(): + upsample_scales = [2, 2, 5] + n_rnn = 16 + n_fc = 16 + n_classes = 10 + hop_length = 20 + n_batch = 2 + n_time = 20 + n_freq = 10 + n_output = 16 + n_res_block = 3 + n_hidden = 16 + kernel_size = 5 + + model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, + n_output) + + def data_gen(): + x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) + mels = torch.rand(n_batch, 1, n_freq, n_time) + return dict(waveform=x, specgram=mels) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +def test_convtasnet_config(): + batch_size = 32 + num_frames = 800 + + model = ConvTasNet() + + def data_gen(): + tensor = torch.rand(batch_size, 1, num_frames) + return dict(input=tensor) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +def test_deepspeech(): + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 32 + + model = DeepSpeech(n_feature=n_feature, n_class=n_class) + + def data_gen(): + x = torch.rand(n_batch, n_channel, n_time, n_feature) + return dict(x=x) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=False) + + +if __name__ == '__main__': + TEST_LIST = [ + test_wave2letter_waveform, + test_wave2letter_mfcc, + test_melresnet_waveform, + test_upsample_network_waveform, + test_wavernn_waveform, + test_convtasnet_config, + test_deepspeech, + ] + + for test_fn in TEST_LIST: + test_fn() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py new file mode 100644 index 000000000..165ac6bb0 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py @@ -0,0 +1,57 @@ +import torch +from torchaudio.models import Tacotron2 +from torchaudio_utils import trace_and_compare +import pytest + + +def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5): + return Tacotron2( + mask_padding=False, + n_mels=n_mels, + n_symbol=20, + n_frames_per_step=1, + symbol_embedding_dim=32, + encoder_embedding_dim=32, + encoder_n_convolution=3, + encoder_kernel_size=5, + decoder_rnn_dim=32, + decoder_max_step=decoder_max_step, + decoder_dropout=0.1, + decoder_early_stopping=True, + attention_rnn_dim=32, + attention_hidden_dim=32, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=32, + postnet_n_convolution=5, + postnet_kernel_size=5, + postnet_embedding_dim=512, + gate_threshold=gate_threshold, + ) + + +@pytest.mark.skip +def test_tacotron_model(): + n_mels = 80 + n_batch = 3 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels) + + def data_gen(): + text = torch.randint(0, 148, (n_batch, max_text_length)) + text_lengths = max_text_length * torch.ones((n_batch,)) + mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length) + mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) + return dict(tokens=text, + token_lengths=text_lengths, + mel_specgram=mel_specgram, + mel_specgram_lengths=mel_specgram_lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +if __name__ == "__main__": + test_tacotron_model() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py new file mode 100644 index 000000000..fb473039b --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py @@ -0,0 +1,61 @@ +import torch +from torchaudio_utils import trace_and_compare +from torchaudio.models import Emformer, Conformer +import pytest + + +@pytest.mark.skip +def test_conformer(): + input_dim = 80 + batch_size = 10 + num_frames = 400 + num_heads = 4 + ffn_dim = 128 + num_layers = 4 + depthwise_conv_kernel_size = 31 + + model = Conformer( + input_dim=input_dim, + num_heads=num_heads, + ffn_dim=ffn_dim, + num_layers=num_layers, + depthwise_conv_kernel_size=depthwise_conv_kernel_size, + ) + + def data_gen(): + lengths = torch.randint(1, num_frames, (batch_size,)) + input = torch.rand(batch_size, int(lengths.max()), input_dim) + return dict(input=input, lengths=lengths) + + trace_and_compare(model, data_gen, need_meta=False, need_concrete=True) + + +@pytest.mark.skip +def test_emformer(): + input_dim = 128 + batch_size = 10 + num_heads = 8 + ffn_dim = 256 + num_layers = 3 + segment_length = 4 + num_frames = 400 + right_context_length = 1 + + model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length) + + def data_gen(): + lengths = torch.randint(1, num_frames, (batch_size,)) + input = torch.rand(batch_size, num_frames, input_dim) + return dict(input=input, lengths=lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +@pytest.mark.skip +def test_torchaudio_transformers(): + test_conformer() + test_emformer() + + +if __name__ == "__main__": + test_torchaudio_transformers() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py new file mode 100644 index 000000000..fe25ab97f --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py @@ -0,0 +1,50 @@ +import torch +from torchaudio.models.wav2vec2 import ( + hubert_base, + hubert_large, + hubert_xlarge, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, +) +from torchaudio_utils import trace_and_compare +import pytest + +MODEL_LIST = [ + hubert_base, + hubert_large, + hubert_xlarge, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, +] + + +def _smoke_test(model, device): + model = model.to(device=device) + + batch_size, num_frames = 3, 1024 + + def data_gen(): + waveforms = torch.randn(batch_size, num_frames, device=device) + lengths = torch.randint( + low=0, + high=num_frames, + size=[ + batch_size, + ], + device=device, + ) + return dict(waveforms=waveforms, lengths=lengths) + + trace_and_compare(model, data_gen, need_meta=True, need_concrete=False) + + +@pytest.mark.skip +def test_wav2vec(): + for model_fn in MODEL_LIST: + _smoke_test(model_fn(), 'cpu') + + +if __name__ == "__main__": + test_wav2vec() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py new file mode 100644 index 000000000..cee555df3 --- /dev/null +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -0,0 +1,28 @@ +from colossalai.fx import ColoTracer +import torch +from torch.fx import GraphModule, Tracer + + +def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False): + data = data_gen() + concrete_args = data if need_concrete else {} + meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + tracer = ColoTracer() + + graph = tracer.trace(root=model, concrete_args=concrete_args, meta_args=meta_args) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + model.eval() + gm.eval() + + with torch.no_grad(): + non_fx_out = model(**data) + fx_out = gm(**data) + if isinstance(fx_out, tuple): + for non_fx, fx in zip(non_fx_out, fx_out): + assert torch.allclose(non_fx, + fx), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + else: + assert torch.allclose( + fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'