diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index 78a3620cc..120874e70 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -17,4 +17,14 @@ def torch_nn_normalize(self, input): assert input.dim() == 5 # normalization maintain the same shape as the input - return input.clone() \ No newline at end of file + return input.clone() + + +try: + import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except ImportError: + pass diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index 0b747cef6..f24dd705c 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -2,15 +2,6 @@ import pytest import transformers import torch from hf_utils import split_model_and_compare_output -from colossalai.fx.tracer.meta_patch import meta_patched_module -try: - import apex - - @meta_patched_module.register(apex.normalization.FusedRMSNorm) - def apex_fused_layernorm(self, input): - return torch.empty(input.shape, device='meta') -except ImportError: - pass BATCH_SIZE = 1 SEQ_LENGHT = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 989cc9c12..4e2614056 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -1,18 +1,8 @@ import pytest import transformers import torch -from colossalai.fx.tracer.meta_patch import meta_patched_module from utils import trace_model_and_compare_output -try: - import apex - - @meta_patched_module.register(apex.normalization.FusedRMSNorm) - def apex_fused_layernorm(self, input): - return torch.empty(input.shape, device='meta') -except ImportError: - pass - BATCH_SIZE = 1 SEQ_LENGHT = 16