mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -23,24 +23,31 @@ def conformer_data_gen_fn():
 
 transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])
 
-model_zoo.register(name='torchaudio_conformer',
-                   model_fn=lambda: tm.Conformer(
-                       input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31),
-                   data_gen_fn=conformer_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_conformer",
+    model_fn=lambda: tm.Conformer(
+        input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31
+    ),
+    data_gen_fn=conformer_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+)
 
 single_output_transform_fn = lambda output: dict(output=output)
 
-model_zoo.register(name='torchaudio_convtasnet',
-                   model_fn=tm.ConvTasNet,
-                   data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
-                   output_transform_fn=single_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_convtasnet",
+    model_fn=tm.ConvTasNet,
+    data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
+    output_transform_fn=single_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
 
-model_zoo.register(name='torchaudio_deepspeech',
-                   model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_deepspeech",
+    model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
+    output_transform_fn=single_output_transform_fn,
+)
 
 
 def emformer_data_gen_fn():
@@ -50,21 +57,26 @@ def emformer_data_gen_fn():
 
 
 model_zoo.register(
-    name='torchaudio_emformer',
+    name="torchaudio_emformer",
     model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
     data_gen_fn=emformer_data_gen_fn,
     output_transform_fn=transformer_output_transform_fn,
-    model_attribute=ModelAttribute(has_control_flow=True))
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
 
-model_zoo.register(name='torchaudio_wav2letter_waveform',
-                   model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_wav2letter_waveform",
+    model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+    output_transform_fn=single_output_transform_fn,
+)
 
-model_zoo.register(name='torchaudio_wav2letter_mfcc',
-                   model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_wav2letter_mfcc",
+    model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+    output_transform_fn=single_output_transform_fn,
+)
 
 
 def wavernn_data_gen_fn():
@@ -73,20 +85,24 @@ def wavernn_data_gen_fn():
     return dict(waveform=waveform, specgram=specgram)
 
 
-model_zoo.register(name='torchaudio_wavernn',
-                   model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5],
-                                               n_classes=N_CLASSES,
-                                               hop_length=HOP_LENGTH,
-                                               kernel_size=KERNEL_SIZE,
-                                               n_freq=N_FREQ,
-                                               n_res_block=2,
-                                               n_rnn=64,
-                                               n_fc=64,
-                                               n_hidden=16,
-                                               n_output=16),
-                   data_gen_fn=wavernn_data_gen_fn,
-                   output_transform_fn=single_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_wavernn",
+    model_fn=lambda: tm.WaveRNN(
+        upsample_scales=[2, 2, 5],
+        n_classes=N_CLASSES,
+        hop_length=HOP_LENGTH,
+        kernel_size=KERNEL_SIZE,
+        n_freq=N_FREQ,
+        n_res_block=2,
+        n_rnn=64,
+        n_fc=64,
+        n_hidden=16,
+        n_output=16,
+    ),
+    data_gen_fn=wavernn_data_gen_fn,
+    output_transform_fn=single_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
 
 
 def tacotron_data_gen_fn():
@@ -97,17 +113,18 @@ def tacotron_data_gen_fn():
     token_lengths = max_text_length * torch.ones((n_batch,))
     mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
     mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
-    return dict(tokens=tokens,
-                token_lengths=token_lengths,
-                mel_specgram=mel_specgram,
-                mel_specgram_lengths=mel_specgram_lengths)
+    return dict(
+        tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths
+    )
 
 
-model_zoo.register(name='torchaudio_tacotron',
-                   model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
-                   data_gen_fn=tacotron_data_gen_fn,
-                   output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_tacotron",
+    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
+    data_gen_fn=tacotron_data_gen_fn,
+    output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
 
 
 def wav2vec_data_gen_fn():
@@ -117,14 +134,18 @@ def wav2vec_data_gen_fn():
     return dict(waveforms=waveforms, lengths=lengths)
 
 
-model_zoo.register(name='torchaudio_wav2vec2_base',
-                   model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
-                   data_gen_fn=wav2vec_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_wav2vec2_base",
+    model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
+    data_gen_fn=wav2vec_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
 
-model_zoo.register(name='torchaudio_hubert_base',
-                   model_fn=tm.hubert_base,
-                   data_gen_fn=wav2vec_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_hubert_base",
+    model_fn=tm.hubert_base,
+    data_gen_fn=wav2vec_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
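
The register calls pair each model constructor with a data generator and an output transform so a test harness can drive every entry the same way. As a rough, self-contained sketch (not part of this diff) of how the conformer entry's three callables fit together; the INPUT_DIM value and the data shapes below are assumptions:

import torch
import torchaudio.models as tm

INPUT_DIM = 80  # assumed for illustration; the real constant is defined elsewhere in this file

# The three callables registered for "torchaudio_conformer", as formatted above.
model_fn = lambda: tm.Conformer(
    input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31
)
# Hypothetical data generator: Conformer.forward(input, lengths) expects
# input of shape (batch, frames, input_dim) plus per-sample frame counts.
data_gen_fn = lambda: dict(input=torch.rand(4, 10, INPUT_DIM), lengths=torch.full((4,), 10))
output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])

model = model_fn()
data = data_gen_fn()
outputs = model(**data)                  # forward pass; returns (frames, lengths)
out_dict = output_transform_fn(outputs)  # normalized dict a test can compare
print(out_dict["frames"].shape)          # torch.Size([4, 10, 80])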