mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[booster] add low level zero plugin (#3594)
* [booster] add low level zero plugin * [booster] fix gemini plugin test * [booster] fix precision * [booster] add low level zero plugin test * [test] fix booster plugin test oom * [test] fix booster plugin test oom * [test] fix googlenet and inception output trans * [test] fix diffuser clip vision model * [test] fix torchaudio_wav2vec2_base * [test] fix low level zero plugin test
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torchaudio.models as tm
|
||||
|
||||
@@ -101,13 +103,11 @@ def tacotron_data_gen_fn():
|
||||
mel_specgram_lengths=mel_specgram_lengths)
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name='torchaudio_tacotron',
|
||||
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
|
||||
data_gen_fn=tacotron_data_gen_fn,
|
||||
output_transform_fn=lambda outputs: dict(
|
||||
spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='torchaudio_tacotron',
|
||||
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
|
||||
data_gen_fn=tacotron_data_gen_fn,
|
||||
output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
||||
|
||||
def wav2vec_data_gen_fn():
|
||||
@@ -118,7 +118,7 @@ def wav2vec_data_gen_fn():
|
||||
|
||||
|
||||
model_zoo.register(name='torchaudio_wav2vec2_base',
|
||||
model_fn=tm.wav2vec2_base,
|
||||
model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
|
||||
data_gen_fn=wav2vec_data_gen_fn,
|
||||
output_transform_fn=transformer_output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
Reference in New Issue
Block a user