ColossalAI/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
YuliangLiu0306 f57d34958b
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
2023-03-22 10:40:33 +08:00

30 lines
1.1 KiB
Python

import torch
from .tracer import register_leaf_module, register_leaf_module_impl
# Optional apex support: when apex is installed, register its fused
# normalization layers with the tracer as leaf modules, each using
# `torch_nn_normalize` as its registered implementation.
try:
    import apex
except ImportError:
    # apex is an optional dependency; without it there is nothing to register.
    pass
else:

    def torch_nn_normalize(self, input: torch.Tensor):
        """Implementation registered for apex normalization leaf modules.

        Normalization keeps the shape of its input, so the result is modelled
        as a clone of ``input``. (Presumably consumed by the tracer's meta/shape
        propagation -- TODO confirm against ``.tracer``.)

        Args:
            self: the apex normalization module instance (unused).
            input: the tensor passed to the module's forward.

        Returns:
            A clone of ``input`` (same shape, dtype and device).
        """
        return input.clone()

    def _register_apex_normalization_leaves() -> None:
        """Register every apex normalization class this apex build provides."""
        normalization = getattr(apex, "normalization", None)
        class_names = (
            "FusedLayerNorm",
            "FusedRMSNorm",
            "MixedFusedLayerNorm",
            "MixedFusedRMSNorm",
        )
        for class_name in class_names:
            # Look classes up one at a time: an apex build missing one of them
            # must not leave the others registered as leaf modules without an
            # implementation (what a single all-or-nothing try block did).
            module_cls = getattr(normalization, class_name, None)
            if module_cls is None:
                continue
            register_leaf_module(module_cls)
            # Equivalent to decorating torch_nn_normalize with
            # @register_leaf_module_impl(module_cls).
            register_leaf_module_impl(module_cls)(torch_nn_normalize)

    _register_apex_normalization_leaves()