mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[booster] add low level zero plugin (#3594)
* [booster] add low level zero plugin * [booster] fix gemini plugin test * [booster] fix precision * [booster] add low level zero plugin test * [test] fix booster plugin test oom * [test] fix booster plugin test oom * [test] fix googlenet and inception output trans * [test] fix diffuser clip vision model * [test] fix torchaudio_wav2vec2_base * [test] fix low level zero plugin test
This commit is contained in:
@@ -36,12 +36,12 @@ def swin_s():
|
||||
|
||||
|
||||
# special output transform fn
|
||||
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
|
||||
) else dict(output=x)
|
||||
google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs
|
||||
) else dict(output=x)
|
||||
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
|
||||
for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
|
||||
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
|
||||
) else dict(output=x)
|
||||
inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs
|
||||
) else dict(output=x)
|
||||
|
||||
model_zoo.register(name='torchvision_alexnet',
|
||||
model_fn=tm.alexnet,
|
||||
|
Reference in New Issue
Block a user