mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[Hotfix] hotfix normalization (#6163)
* [fix] hotfix normalization * [hotfix] force doc ci test * [hotfix] fallback doc
This commit is contained in:
parent
130229fdcb
commit
fa9d0318e4
@ -76,6 +76,7 @@ if SUPPORT_NPU:
|
|||||||
|
|
||||||
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
FusedRMSNormWithHook = NPUFusedRMSNormWithHook
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||||
|
|
||||||
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm):
|
||||||
@ -88,6 +89,11 @@ else:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
FusedRMSNormWithHook = CUDAFusedRMSNormWithHook
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn(
|
||||||
|
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
1024,
|
1024,
|
||||||
|
Loading…
Reference in New Issue
Block a user