Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* skip tests
* fix CI
* using packaging version
* polish
colossalai/_analyzer/fx/tracer/custom_leaf_module.py (new file, 29 additions)
@@ -0,0 +1,29 @@
import torch

from .tracer import register_leaf_module, register_leaf_module_impl

try:
    import apex
    register_leaf_module(apex.normalization.FusedLayerNorm)
    register_leaf_module(apex.normalization.FusedRMSNorm)
    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
    register_leaf_module(apex.normalization.MixedFusedRMSNorm)

    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
    def torch_nn_normalize(self, input: torch.Tensor):
        # check shape
        if isinstance(self, torch.nn.BatchNorm1d):
            assert input.dim() in [2, 3]
        elif isinstance(self, torch.nn.BatchNorm2d):
            assert input.dim() == 4
        elif isinstance(self, torch.nn.BatchNorm3d):
            assert input.dim() == 5

        # normalization maintains the same shape as the input
        return input.clone()

except (ImportError, AttributeError):
    pass
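The same registration pattern can be extended to other shape-preserving modules. The sketch below is illustrative only: MyFusedNorm and my_fused_norm_meta are hypothetical names, the absolute import path is inferred from the relative import in this file, and it assumes register_leaf_module marks a class as a tracer leaf while register_leaf_module_impl attaches a meta-propagation function that returns a tensor of the output shape, as used above.

import torch

# import path inferred from the relative import in custom_leaf_module.py
from colossalai._analyzer.fx.tracer.tracer import (register_leaf_module,
                                                    register_leaf_module_impl)


class MyFusedNorm(torch.nn.Module):
    """Hypothetical shape-preserving normalization module, for illustration only."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * self.weight


# treat MyFusedNorm as an opaque leaf during symbolic tracing
register_leaf_module(MyFusedNorm)


@register_leaf_module_impl(MyFusedNorm)
def my_fused_norm_meta(self, input: torch.Tensor):
    # meta propagation: the output has the same shape as the input
    return input.clone()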