mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 05:29:36 +00:00
[Hotfix] hotfix normalization (#6163)
* [fix] hotfix normalization * [hotfix] force doc ci test * [hotfix] fallback doc
This commit is contained in:
parent
130229fdcb
commit
fa9d0318e4
@ -76,18 +76,24 @@ if SUPPORT_NPU:
|
||||
|
||||
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
||||
else:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
|
||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
|
||||
super().__init__(normalized_shape, eps, elementwise_affine)
|
||||
|
||||
def forward(self, input):
|
||||
output = super().forward(input)
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
def forward(self, input):
|
||||
output = super().forward(input)
|
||||
output = hook_parameter_in_backward(output, self.weight)
|
||||
return output
|
||||
|
||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
|
||||
)
|
||||
|
||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024,
|
||||
|
Loading…
Reference in New Issue
Block a user