[tests] model zoo add torchaudio models (#3138)

* [tests] model zoo add torchaudio models

* [tests] refactor torchaudio wavernn

* [tests] refactor fx torchaudio tests
Author: ver217
Date: 2023-03-15 11:51:16 +08:00
Committed by: GitHub
Parent: 6d48eb0560
Commit: 14a115000b

9 changed files with 166 additions and 333 deletions


@@ -1,145 +0,0 @@
import torch
from torchaudio_utils import trace_and_compare
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
import pytest


def test_wave2letter_waveform():
    batch_size = 2
    num_features = 1
    num_classes = 40
    input_length = 320

    model = Wav2Letter(num_classes=num_classes, num_features=num_features)

    def data_gen():
        x = torch.rand(batch_size, num_features, input_length)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_wave2letter_mfcc():
    batch_size = 2
    num_features = 13
    num_classes = 40
    input_length = 2

    model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)

    def data_gen():
        x = torch.rand(batch_size, num_features, input_length)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_melresnet_waveform():
    n_batch = 2
    n_time = 200
    n_freq = 100
    n_output = 128
    n_res_block = 10
    n_hidden = 128
    kernel_size = 5

    model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        x = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_upsample_network_waveform():
    upsample_scales = [5, 5, 8]
    n_batch = 2
    n_time = 200
    n_freq = 100
    n_output = 64
    n_res_block = 10
    n_hidden = 32
    kernel_size = 5

    total_scale = 1
    for upsample_scale in upsample_scales:
        total_scale *= upsample_scale

    model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)

    def data_gen():
        x = torch.rand(n_batch, n_freq, n_time)
        return dict(specgram=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


def test_wavernn_waveform():
    upsample_scales = [2, 2, 5]
    n_rnn = 16
    n_fc = 16
    n_classes = 10
    hop_length = 20
    n_batch = 2
    n_time = 20
    n_freq = 10
    n_output = 16
    n_res_block = 3
    n_hidden = 16
    kernel_size = 5

    model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden,
                    n_output)

    def data_gen():
        x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
        mels = torch.rand(n_batch, 1, n_freq, n_time)
        return dict(waveform=x, specgram=mels)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


def test_convtasnet_config():
    batch_size = 32
    num_frames = 800

    model = ConvTasNet()

    def data_gen():
        tensor = torch.rand(batch_size, 1, num_frames)
        return dict(input=tensor)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


def test_deepspeech():
    n_batch = 2
    n_feature = 1
    n_channel = 1
    n_class = 40
    n_time = 32

    model = DeepSpeech(n_feature=n_feature, n_class=n_class)

    def data_gen():
        x = torch.rand(n_batch, n_channel, n_time, n_feature)
        return dict(x=x)

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)


if __name__ == '__main__':
    TEST_LIST = [
        test_wave2letter_waveform,
        test_wave2letter_mfcc,
        test_melresnet_waveform,
        test_upsample_network_waveform,
        test_wavernn_waveform,
        test_convtasnet_config,
        test_deepspeech,
    ]
    for test_fn in TEST_LIST:
        test_fn()
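For reference, every deleted test above follows the same pattern: build a model, generate a kwargs dict, and let trace_and_compare check the traced GraphModule against eager execution. A standalone sketch of the simplest case, not part of this commit (it assumes colossalai.fx.symbolic_trace can be called without meta or concrete args, as the helper diff at the end implies):

    import torch
    from torchaudio.models import Wav2Letter
    from colossalai.fx import symbolic_trace

    model = Wav2Letter(num_classes=40, num_features=1)
    data = dict(x=torch.rand(2, 1, 320))
    # Trace to a GraphModule, then check the traced output matches eager mode.
    gm = symbolic_trace(model)
    with torch.no_grad():
        assert torch.allclose(model(**data), gm(**data), atol=1e-5)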


@@ -0,0 +1,22 @@
import re

import torch
from torchaudio_utils import trace_and_compare

from tests.kit.model_zoo import model_zoo


def test_torchaudio_models():
    torch.backends.cudnn.deterministic = True

    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
        # FIXME(ver217): temporarily skip these models
        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
            continue
        model = model_fn()
        trace_and_compare(model,
                          data_gen_fn,
                          output_transform_fn,
                          need_meta=(attribute is not None and attribute.has_control_flow))
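The new test consumes registry entries as 4-tuples of (model_fn, data_gen_fn, output_transform_fn, attribute). A hypothetical registration for one of the migrated models, not part of this commit; the actual registry API lives in tests/kit/model_zoo, so names like model_zoo.register and ModelAttribute are assumptions inferred from how the tuple is unpacked above:

    import torch
    from torchaudio.models import Wav2Letter
    # Assumed imports: the registry and its attribute type live in tests/kit/model_zoo.
    from tests.kit.model_zoo import model_zoo, ModelAttribute

    model_zoo.register(name='torchaudio_wav2letter',
                       model_fn=lambda: Wav2Letter(num_classes=40, num_features=1),
                       data_gen_fn=lambda: dict(x=torch.rand(2, 1, 320)),
                       output_transform_fn=lambda out: dict(output=out),
                       model_attribute=ModelAttribute(has_control_flow=False))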


@@ -1,57 +0,0 @@
import torch
from torchaudio.models import Tacotron2
from torchaudio_utils import trace_and_compare
import pytest


def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
    return Tacotron2(
        mask_padding=False,
        n_mels=n_mels,
        n_symbol=20,
        n_frames_per_step=1,
        symbol_embedding_dim=32,
        encoder_embedding_dim=32,
        encoder_n_convolution=3,
        encoder_kernel_size=5,
        decoder_rnn_dim=32,
        decoder_max_step=decoder_max_step,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=32,
        attention_hidden_dim=32,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=32,
        postnet_n_convolution=5,
        postnet_kernel_size=5,
        postnet_embedding_dim=512,
        gate_threshold=gate_threshold,
    )


@pytest.mark.skip("Tracing failed")
def test_tacotron_model():
    n_mels = 80
    n_batch = 3
    max_mel_specgram_length = 300
    max_text_length = 100

    model = _get_tacotron2_model(n_mels)

    def data_gen():
        text = torch.randint(0, 148, (n_batch, max_text_length))
        text_lengths = max_text_length * torch.ones((n_batch,))
        mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length)
        mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
        return dict(tokens=text,
                    token_lengths=text_lengths,
                    mel_specgram=mel_specgram,
                    mel_specgram_lengths=mel_specgram_lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


if __name__ == "__main__":
    test_tacotron_model()
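Tacotron2's decoder runs a data-dependent loop (it stops once the gate output crosses gate_threshold), which symbolic tracing cannot resolve; presumably this is why the test is skipped as "Tracing failed". A minimal reproduction of that failure mode with a toy module rather than Tacotron2 itself, using stock torch.fx:

    import torch
    import torch.fx

    class DataDependentLoop(torch.nn.Module):
        def forward(self, x):
            # The loop condition depends on tensor *values*, which a symbolic
            # Proxy cannot answer -- fx raises a TraceError at the bool check.
            while x.sum() < 10:
                x = x + 1
            return x

    try:
        torch.fx.symbolic_trace(DataDependentLoop())
    except Exception as e:
        print(type(e).__name__, e)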


@@ -1,67 +0,0 @@
import torch
from torchaudio_utils import trace_and_compare
from torchaudio.models import Emformer, Conformer
import pytest


def test_conformer():
    input_dim = 80
    batch_size = 10
    num_frames = 400
    num_heads = 4
    ffn_dim = 128
    num_layers = 4
    depthwise_conv_kernel_size = 31

    model = Conformer(
        input_dim=input_dim,
        num_heads=num_heads,
        ffn_dim=ffn_dim,
        num_layers=num_layers,
        depthwise_conv_kernel_size=depthwise_conv_kernel_size,
    )

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        input = torch.rand(batch_size, int(lengths.max()), input_dim)
        return dict(input=input, lengths=lengths)

    def kwargs_transform(data):
        new_data = {}
        for k, v in data.items():
            new_data[f'{k}_1'] = v
        return new_data

    trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)


@pytest.mark.skip("Tracing failed")
def test_emformer():
    input_dim = 128
    batch_size = 10
    num_heads = 8
    ffn_dim = 256
    num_layers = 3
    segment_length = 4
    num_frames = 400
    right_context_length = 1

    model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length)

    def data_gen():
        lengths = torch.randint(1, num_frames, (batch_size,))
        input = torch.rand(batch_size, num_frames, input_dim)
        return dict(input=input, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


@pytest.mark.skip
def test_torchaudio_transformers():
    test_conformer()
    test_emformer()


if __name__ == "__main__":
    test_torchaudio_transformers()
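Note the kwargs_transform closure in test_conformer: once Conformer is traced with concrete args, the resulting GraphModule expects its inputs under suffixed names (input_1, lengths_1), so the original kwargs are renamed before the call. Rather than hardcoding the suffix, the expected names can be read off the traced graph with the plain torch.fx node API; a small sketch:

    import torch.fx

    def placeholder_names(gm: torch.fx.GraphModule):
        # Placeholder nodes are the graph's inputs; `target` holds the
        # argument name the GraphModule's forward() expects.
        return [node.target for node in gm.graph.nodes if node.op == 'placeholder']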


@@ -1,50 +0,0 @@
import torch
from torchaudio.models.wav2vec2 import (
    hubert_base,
    hubert_large,
    hubert_xlarge,
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from torchaudio_utils import trace_and_compare
import pytest

MODEL_LIST = [
    hubert_base,
    hubert_large,
    hubert_xlarge,
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
]


def _smoke_test(model, device):
    model = model.to(device=device)

    batch_size, num_frames = 3, 1024

    def data_gen():
        waveforms = torch.randn(batch_size, num_frames, device=device)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
            device=device,
        )
        return dict(waveforms=waveforms, lengths=lengths)

    trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)


@pytest.mark.skip("Tracing failed")
def test_wav2vec():
    for model_fn in MODEL_LIST:
        _smoke_test(model_fn(), 'cpu')


if __name__ == "__main__":
    test_wav2vec()


@@ -3,7 +3,7 @@ import torch
 from colossalai.fx import symbolic_trace
 
 
-def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
+def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
     data = data_gen()
     concrete_args = data if need_concrete else {}
     meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
@@ -14,16 +14,15 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
     with torch.no_grad():
         non_fx_out = model(**data)
-        if kwargs_transform:
-            data = kwargs_transform(data)
         fx_out = gm(**data)
 
-    if isinstance(fx_out, tuple):
-        for non_fx, fx in zip(non_fx_out, fx_out):
-            assert torch.allclose(
-                non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-    else:
-        assert torch.allclose(
-            fx_out, non_fx_out,
-            atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+    for key, fx_output_val in transformed_fx_out.items():
+        non_fx_output_val = transformed_non_fx_out[key]
+        assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+            f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
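The rewritten comparison assumes output_transform_fn maps any model output to a flat dict of tensors, keyed identically for the eager and traced runs, so outputs can be compared entry by entry whether the model returns a tensor or a tuple. A minimal transform satisfying that contract (a sketch; the transforms the model zoo actually registers may differ):

    import torch

    def output_transform_fn(output):
        # Wrap a lone tensor, or index a tuple of tensors, into a keyed dict
        # so both sides of the comparison share the same keys.
        if isinstance(output, torch.Tensor):
            return {'output': output}
        return {f'output_{i}': t for i, t in enumerate(output)}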