[booster] add low level zero plugin (#3594)

* [booster] add low level zero plugin

* [booster] fix gemini plugin test

* [booster] fix precision

* [booster] add low level zero plugin test

* [test] fix booster plugin test oom

* [test] fix booster plugin test oom

* [test] fix googlenet and inception output trans

* [test] fix diffuser clip vision model

* [test] fix torchaudio_wav2vec2_base

* [test] fix low level zero plugin test
This commit is contained in:
Hongxin Liu
2023-04-26 14:37:25 +08:00
committed by GitHub
parent b9a8dff7e5
commit 4b3240cb59
9 changed files with 476 additions and 81 deletions

View File

@@ -1,3 +1,5 @@
from functools import partial
import torch
import torchaudio.models as tm
@@ -101,13 +103,11 @@ def tacotron_data_gen_fn():
mel_specgram_lengths=mel_specgram_lengths)
model_zoo.register(
name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(
spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
model_attribute=ModelAttribute(has_control_flow=True))
def wav2vec_data_gen_fn():
@@ -118,7 +118,7 @@ def wav2vec_data_gen_fn():
model_zoo.register(name='torchaudio_wav2vec2_base',
model_fn=tm.wav2vec2_base,
model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
data_gen_fn=wav2vec_data_gen_fn,
output_transform_fn=transformer_output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))