mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 07:47:05 +00:00
[fx] added apex normalization to patched modules (#1300)
* [fx] added apex normalization to patched modules * remove unused imports
This commit is contained in:
parent
4165eabb1e
commit
4f4d8c3656
@ -18,3 +18,13 @@ def torch_nn_normalize(self, input):
|
|||||||
|
|
||||||
# normalization maintain the same shape as the input
|
# normalization maintain the same shape as the input
|
||||||
return input.clone()
|
return input.clone()
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
|
||||||
|
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
|
||||||
|
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
|
||||||
|
meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
@ -2,15 +2,6 @@ import pytest
|
|||||||
import transformers
|
import transformers
|
||||||
import torch
|
import torch
|
||||||
from hf_utils import split_model_and_compare_output
|
from hf_utils import split_model_and_compare_output
|
||||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
|
||||||
try:
|
|
||||||
import apex
|
|
||||||
|
|
||||||
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
|
|
||||||
def apex_fused_layernorm(self, input):
|
|
||||||
return torch.empty(input.shape, device='meta')
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
BATCH_SIZE = 1
|
BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
@ -1,18 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import transformers
|
import transformers
|
||||||
import torch
|
import torch
|
||||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
|
||||||
from utils import trace_model_and_compare_output
|
from utils import trace_model_and_compare_output
|
||||||
|
|
||||||
try:
|
|
||||||
import apex
|
|
||||||
|
|
||||||
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
|
|
||||||
def apex_fused_layernorm(self, input):
|
|
||||||
return torch.empty(input.shape, device='meta')
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
BATCH_SIZE = 1
|
BATCH_SIZE = 1
|
||||||
SEQ_LENGHT = 16
|
SEQ_LENGHT = 16
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user