diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index 120874e70..e83b31b67 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -26,5 +26,5 @@ try: meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) -except ImportError: +except (ImportError, AttributeError): pass