mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,28 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
|
||||
__all__ = ['FusedLayerNorm', 'FusedRMSNorm']
|
||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
|
||||
25600, 30720, 32768, 40960, 49152, 65536
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
2304,
|
||||
3072,
|
||||
3840,
|
||||
4096,
|
||||
5120,
|
||||
6144,
|
||||
8192,
|
||||
10240,
|
||||
12288,
|
||||
12800,
|
||||
15360,
|
||||
16384,
|
||||
18432,
|
||||
20480,
|
||||
24576,
|
||||
25600,
|
||||
30720,
|
||||
32768,
|
||||
40960,
|
||||
49152,
|
||||
65536,
|
||||
]
|
||||
|
||||
|
||||
class FusedLayerNorm():
|
||||
class FusedLayerNorm:
|
||||
r"""
|
||||
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
'FusedLayerNorm is not implemented as a physical class. '
|
||||
'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
|
||||
"FusedLayerNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -32,10 +53,11 @@ class FusedLayerNorm():
|
||||
"""
|
||||
# check if apex is installed
|
||||
try:
|
||||
import apex
|
||||
pass
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
|
||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
|
||||
)
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes of the module
|
||||
@@ -57,23 +79,24 @@ class FusedLayerNorm():
|
||||
else:
|
||||
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
|
||||
|
||||
layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
|
||||
elementwise_affine=elementwise_affine).to(dtype).to(device)
|
||||
layernorm = (
|
||||
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
|
||||
)
|
||||
|
||||
layernorm.weight = module.weight
|
||||
layernorm.bias = module.bias
|
||||
return layernorm
|
||||
|
||||
|
||||
class FusedRMSNorm():
|
||||
class FusedRMSNorm:
|
||||
"""
|
||||
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
'FusedRMSNorm is not implemented as a physical class. '
|
||||
'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
|
||||
"FusedRMSNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -82,7 +105,7 @@ class FusedRMSNorm():
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
|
||||
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
|
||||
)
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
|
Reference in New Issue
Block a user