mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -27,7 +27,7 @@ from colossalai.utils.cuda import get_current_device
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..colossalai_layer._utils import ColossalaiModule
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
|
||||
from ..vanilla import VanillaPatchEmbedding
|
||||
from ._operation import linear_with_async_comm
|
||||
from ._utils import (
|
||||
gather_forward_split_backward,
|
||||
@@ -41,6 +41,7 @@ from ._utils import (
|
||||
Fast_LN = None
|
||||
try:
|
||||
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
|
||||
|
||||
Fast_LN = FastLayerNorm
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -67,33 +68,39 @@ class Linear1D(ColossalaiModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
parallel_input = get_parallel_input()
|
||||
if not parallel_input and not gather_output:
|
||||
layer = Linear1D_Col(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
layer = Linear1D_Col(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
else:
|
||||
layer = Linear1D_Row(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
parallel_input=parallel_input,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
layer = Linear1D_Row(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
parallel_input=parallel_input,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
super().__init__(layer)
|
||||
|
||||
|
||||
@@ -114,8 +121,30 @@ class LayerNorm1D(ColossalaiModule):
|
||||
"""
|
||||
|
||||
_fast_ln_supported_sizes = [
|
||||
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
|
||||
24576, 25600, 30720, 32768, 40960, 49152, 65536
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
2304,
|
||||
3072,
|
||||
3840,
|
||||
4096,
|
||||
5120,
|
||||
6144,
|
||||
8192,
|
||||
10240,
|
||||
12288,
|
||||
12800,
|
||||
15360,
|
||||
16384,
|
||||
18432,
|
||||
20480,
|
||||
24576,
|
||||
25600,
|
||||
30720,
|
||||
32768,
|
||||
40960,
|
||||
49152,
|
||||
65536,
|
||||
]
|
||||
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
|
||||
@@ -125,6 +154,7 @@ class LayerNorm1D(ColossalaiModule):
|
||||
norm = None
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
except ImportError:
|
||||
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
@@ -132,8 +162,8 @@ class LayerNorm1D(ColossalaiModule):
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
@@ -171,14 +201,16 @@ class Classifier1D(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
@@ -189,7 +221,7 @@ class Classifier1D(ParallelLayer):
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
@@ -221,8 +253,8 @@ class Classifier1D(ParallelLayer):
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
if self.has_weight:
|
||||
@@ -235,50 +267,46 @@ class Classifier1D(ParallelLayer):
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: False},
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
local_state = OrderedDict()
|
||||
if self.has_weight:
|
||||
local_state[weight_key] = self.weight
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: False},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
assert (
|
||||
divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1]
|
||||
), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
@@ -307,15 +335,17 @@ class VocabParallelClassifier1D(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
@@ -327,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
@@ -360,8 +390,8 @@ class VocabParallelClassifier1D(ParallelLayer):
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
if self.has_weight:
|
||||
@@ -374,43 +404,37 @@ class VocabParallelClassifier1D(ParallelLayer):
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: True},
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
local_state = OrderedDict()
|
||||
if self.has_weight:
|
||||
local_state[weight_key] = self.weight
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
# Matrix multiply.
|
||||
@@ -449,15 +473,17 @@ class Linear1D_Col(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
@@ -467,13 +493,13 @@ class Linear1D_Col(ParallelLayer):
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
@@ -500,8 +526,8 @@ class Linear1D_Col(ParallelLayer):
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
@@ -513,41 +539,35 @@ class Linear1D_Col(ParallelLayer):
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: True},
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
# Set up backprop all-reduce.
|
||||
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
input_parallel = input_
|
||||
@@ -569,7 +589,7 @@ class Linear1D_Col(ParallelLayer):
|
||||
|
||||
@LAYERS.register_module
|
||||
class Linear1D_Row(ParallelLayer):
|
||||
r""" Linear layer with row parallelism
|
||||
r"""Linear layer with row parallelism
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
@@ -588,16 +608,18 @@ class Linear1D_Row(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stream_chunk_num = stream_chunk_num
|
||||
@@ -609,14 +631,14 @@ class Linear1D_Row(ParallelLayer):
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
@@ -647,8 +669,8 @@ class Linear1D_Row(ParallelLayer):
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
@@ -660,48 +682,44 @@ class Linear1D_Row(ParallelLayer):
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: False},
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
weight_key = prefix + "weight"
|
||||
bias_key = prefix + "bias"
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1, bias_key: 0},
|
||||
partition_states={weight_key: True, bias_key: False},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
assert (
|
||||
divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
@@ -712,9 +730,9 @@ class Linear1D_Row(ParallelLayer):
|
||||
handle_list = []
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
|
||||
handle = torch.distributed.all_reduce(output_parallel_list[i],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D),
|
||||
async_op=True)
|
||||
handle = torch.distributed.all_reduce(
|
||||
output_parallel_list[i], group=gpc.get_group(ParallelMode.PARALLEL_1D), async_op=True
|
||||
)
|
||||
handle_list.append(handle)
|
||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
for handle in handle_list:
|
||||
@@ -763,14 +781,16 @@ class Embedding1D(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
@@ -782,7 +802,8 @@ class Embedding1D(ParallelLayer):
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
@@ -804,31 +825,31 @@ class Embedding1D(ParallelLayer):
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
weight_key = prefix + "weight"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state, ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True}
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
weight_key = prefix + "weight"
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
@@ -867,14 +888,16 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
@@ -889,7 +912,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
@@ -906,34 +930,38 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None and \
|
||||
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
||||
if (
|
||||
self.padding_idx is not None
|
||||
and self.padding_idx >= self.vocab_start_index
|
||||
and self.padding_idx < self.vocab_end_index
|
||||
):
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
weight_key = prefix + "weight"
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True})
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state, ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True}
|
||||
)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
weight_key = prefix + "weight"
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars)
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
@@ -943,11 +971,12 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
|
||||
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
|
||||
**self.embed_kwargs)
|
||||
output_parallel = F.embedding(
|
||||
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
|
||||
)
|
||||
|
||||
# Mask the output embedding.
|
||||
output_parallel[input_mask, :] = 0.
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
return output
|
||||
@@ -1002,30 +1031,34 @@ class PatchEmbedding1D(ColossalaiModule):
|
||||
:type position_embed_initializer: typing.Callable, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: torch.dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()):
|
||||
embed = VanillaPatchEmbedding(img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer)
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: torch.dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_(),
|
||||
):
|
||||
embed = VanillaPatchEmbedding(
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer,
|
||||
)
|
||||
super().__init__(embed)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
|
||||
param_keys = [prefix + "weight", prefix + "bias", prefix + "cls_token", prefix + "pos_embed"]
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
for key in param_keys:
|
||||
param = state_dict.pop(key, None)
|
||||
|
Reference in New Issue
Block a user