[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
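For context, a pre-commit setup along these lines reproduces this kind of whole-repo sweep (a minimal sketch, not the repository's actual .pre-commit-config.yaml; the hook revisions and the CUDA exclude pattern are illustrative assumptions):

repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1        # illustrative revision
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6       # illustrative revision
    hooks:
      - id: clang-format
        exclude: \.cu$  # "ignore cuda for clang-format"

Running `pre-commit run --all-files` then rewrites every tracked file at once, which is why the diff below is almost entirely mechanical reformatting.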
@@ -55,14 +55,16 @@ class Linear2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 skip_bias_add: bool = False,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        skip_bias_add: bool = False,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
 
         self.in_features = in_features
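The signature rewrite above is black's "magic trailing comma" style: once the argument list ends with a trailing comma, black places one parameter per line at a fixed continuation indent instead of yapf's alignment under the opening parenthesis. A minimal sketch of the same transformation on a toy function:

# before (hanging alignment under the open paren)
def scale(value: float,
          factor: float = 2.0):
    return value * factor

# after (black, magic trailing comma forces one parameter per line)
def scale(
    value: float,
    factor: float = 2.0,
):
    return value * factor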
@@ -80,15 +82,16 @@ class Linear2D(ParallelLayer):
         self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
         self.weight = Parameter(
-            torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
+            torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
+        )
 
         # create bias, shape: [h/q]
         if bias:
             self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
         # initialize parameters
         with seed(ParallelMode.TENSOR):
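In the shape comments, q is the side of the SUMMA process grid (summa_dim), so each of the q² tensor ranks stores a [k/q, h/q] tile of the full [k, h] weight and a 1/q² slice of the bias. A quick sanity check of those sizes (hypothetical dimensions; this local divide stands in for colossalai's helper, which likewise asserts divisibility):

def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "tensor dims must divide evenly across the grid"
    return numerator // denominator

in_features, out_features, summa_dim = 1024, 4096, 2  # q = 2 -> 4 tensor-parallel ranks
weight_tile = (divide(in_features, summa_dim), divide(out_features, summa_dim))
bias_slice = divide(out_features, summa_dim**2)
print(weight_tile, bias_slice)  # (512, 2048) 1024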
@@ -108,8 +111,8 @@ class Linear2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -126,34 +129,22 @@ class Linear2D(ParallelLayer):
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
         # partition in column groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
 
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict({weight_key: self.weight})
         if self.bias is not None:
             local_state[bias_key] = self.bias
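The two partition passes above slice each full tensor down to this rank's tile: dims names the axis to chunk for every state-dict key, and partition_states marks whether a key is partitioned at all. A single-process sketch of one such pass (an illustration of the semantics, not colossalai's implementation):

import torch

def partition_pass(state, rank, world_size, dims):
    # chunk each tensor along its configured axis and keep this rank's piece
    return {k: torch.chunk(v, world_size, dim=dims[k])[rank].clone() for k, v in state.items()}

full = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}
after_row = partition_pass(full, rank=0, world_size=2, dims={"weight": -1, "bias": 0})
after_col = partition_pass(after_row, rank=1, world_size=2, dims={"weight": 0, "bias": 0})
print(after_col["weight"].shape, after_col["bias"].shape)  # torch.Size([4, 4]) torch.Size([2])

The matching gather_tensor_parallel_state_dict calls in _save_to_global_state_dict invert the two passes in the opposite group order.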
@@ -162,14 +153,8 @@ class Linear2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         # gather in row groups
@@ -177,14 +162,8 @@ class Linear2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -196,22 +175,53 @@ class Linear2D(ParallelLayer):
         # output: [m/q, n/q, h/q]
         out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
 
-        output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
-                                    ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
-                                    self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
+        output = Matmul_AB_2D.apply(
+            x,
+            self.weight,
+            self.summa_dim,
+            out_shape,
+            self.row_rank,
+            self.col_rank,
+            ParallelMode.PARALLEL_2D_ROW,
+            ParallelMode.PARALLEL_2D_COL,
+            self.data_parallel_rank,
+            self.pipeline_parallel_rank,
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+        )
 
         if self.bias is not None:
             if self.skip_bias_add:
-                bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
-                                   ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
-                                   self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                   self.tensor_parallel_size)
+                bias = add_bias_2d(
+                    None,
+                    self.bias,
+                    self.hidden_size_per_partition,
+                    self.row_rank,
+                    self.col_rank,
+                    ParallelMode.PARALLEL_2D_ROW,
+                    ParallelMode.PARALLEL_2D_COL,
+                    True,
+                    self.data_parallel_rank,
+                    self.pipeline_parallel_rank,
+                    self.pipeline_parallel_size,
+                    self.tensor_parallel_size,
+                )
                 return output, bias
             else:
-                output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
-                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
-                                     self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                     self.tensor_parallel_size)
+                output = add_bias_2d(
+                    output,
+                    self.bias,
+                    self.hidden_size_per_partition,
+                    self.row_rank,
+                    self.col_rank,
+                    ParallelMode.PARALLEL_2D_ROW,
+                    ParallelMode.PARALLEL_2D_COL,
+                    False,
+                    self.data_parallel_rank,
+                    self.pipeline_parallel_rank,
+                    self.pipeline_parallel_size,
+                    self.tensor_parallel_size,
+                )
                 return output
         else:
             return output
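Matmul_AB_2D above is the SUMMA-style C = A @ B: each rank holds one [m/q, k/q] tile of A and one [k/q, h/q] tile of B, and at step k the k-th A tiles are broadcast along process-grid rows and the k-th B tiles along columns, with a local matmul accumulated on every rank. A serial sketch of that schedule (broadcasts replaced by direct tile indexing; illustrative only, not the distributed op):

import torch

def summa_ab(A_tiles, B_tiles, q):
    # C[i][j] = sum_k A[i][k] @ B[k][j], accumulated one k-step at a time
    C = [[torch.zeros(A_tiles[i][0].shape[0], B_tiles[0][j].shape[1]) for j in range(q)] for i in range(q)]
    for k in range(q):
        for i in range(q):
            for j in range(q):
                C[i][j] += A_tiles[i][k] @ B_tiles[k][j]
    return C

q = 2
A, B = torch.randn(4, 6), torch.randn(6, 8)
A_tiles = [list(r.chunk(q, dim=1)) for r in A.chunk(q, dim=0)]
B_tiles = [list(r.chunk(q, dim=1)) for r in B.chunk(q, dim=0)]
C = summa_ab(A_tiles, B_tiles, q)
assert torch.allclose(torch.cat([torch.cat(r, dim=1) for r in C], dim=0), A @ B, atol=1e-5)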
@@ -249,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
         self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
 
         # create parameters
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
         if bias:
@@ -266,8 +276,8 @@ class LayerNorm2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -283,34 +293,22 @@ class LayerNorm2D(ParallelLayer):
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
         # partition in column groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
 
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict({weight_key: self.weight})
         if self.bias is not None:
             local_state[bias_key] = self.bias
@@ -319,14 +317,8 @@ class LayerNorm2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         # gather in row groups
@@ -334,14 +326,8 @@ class LayerNorm2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
 
     def forward(self, x: Tensor) -> Tensor:
         with torch.no_grad():
-            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
+            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
             torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
             E_x /= self.normalized_shape
 
             # Var_x in the block below is the sum of input^2
-            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
+            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
             torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
             Var_x /= self.normalized_shape
 
-            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
+            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
             # this time 1/sqrt(Var_x + epsilon)
             Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
 
-        output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
-                              ParallelMode.PARALLEL_2D_COL)
-        scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
-                            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                            self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
+        output = layernorm_2d(
+            x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL
+        )
+        scale = add_bias_2d(
+            None,
+            self.weight,
+            self.partitioned_partition,
+            self.row_rank,
+            self.col_rank,
+            ParallelMode.PARALLEL_2D_ROW,
+            ParallelMode.PARALLEL_2D_COL,
+            True,
+            self.data_parallel_rank,
+            self.pipeline_parallel_rank,
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+        )
         if self.bias is not None:
-            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
-                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                               self.tensor_parallel_size)
+            bias = add_bias_2d(
+                None,
+                self.bias,
+                self.partitioned_partition,
+                self.row_rank,
+                self.col_rank,
+                ParallelMode.PARALLEL_2D_ROW,
+                ParallelMode.PARALLEL_2D_COL,
+                True,
+                self.data_parallel_rank,
+                self.pipeline_parallel_rank,
+                self.pipeline_parallel_size,
+                self.tensor_parallel_size,
+            )
            output = torch.addcmul(bias, scale, output)
         else:
             output = torch.mul(scale, output)
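The forward above computes LayerNorm statistics across a hidden dimension that is split over the row group: each rank sums its h/q slice, the all-reduce completes E[x] and E[x²], and the variance comes from the identity Var[x] = E[x²] − E[x]². A single-rank check of that identity:

import torch

x = torch.randn(2, 5, 16)  # [batch, seq, hidden]; distributed, each rank would hold hidden/q of this
E_x = x.sum(dim=-1, keepdim=True) / x.shape[-1]        # the row-group all-reduce completes this sum
E_x2 = (x * x).sum(dim=-1, keepdim=True) / x.shape[-1]
var = E_x2 - E_x * E_x                                 # Var[x] = E[x^2] - E[x]^2
assert torch.allclose(var, x.var(dim=-1, unbiased=False, keepdim=True), atol=1e-5)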
@@ -400,16 +408,18 @@ class PatchEmbedding2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 embed_size: int,
-                 flatten: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 position_embed_initializer: Callable = init.zeros_()):
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_size: int,
+        flatten: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        position_embed_initializer: Callable = init.zeros_(),
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -426,17 +436,22 @@ class PatchEmbedding2D(ParallelLayer):
 
         with seed(ParallelMode.TENSOR):
             self.weight = Parameter(
-                torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
-                            device=get_current_device(),
-                            dtype=dtype))
+                torch.empty(
+                    (self.embed_size_per_partition, in_chans, *self.patch_size),
+                    device=get_current_device(),
+                    dtype=dtype,
+                )
+            )
             self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
 
             self.cls_token = Parameter(
-                torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
+                torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            )
             self.pos_embed = Parameter(
-                torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
-                            device=get_current_device(),
-                            dtype=dtype))
+                torch.zeros(
+                    (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+                )
+            )
 
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
         self._set_tensor_parallel_attribute()
@@ -457,10 +472,10 @@ class PatchEmbedding2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        cls_token_key = prefix + 'cls_token'
-        pos_embed_key = prefix + 'pos_embed'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
+        cls_token_key = prefix + "cls_token"
+        pos_embed_key = prefix + "pos_embed"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -484,67 +499,34 @@ class PatchEmbedding2D(ParallelLayer):
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: 0,
-                bias_key: 0,
-                cls_token_key: -1,
-                pos_embed_key: -1
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True,
-                cls_token_key: True,
-                pos_embed_key: True
-            },
+            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
         )
         # partition in column groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0,
-                cls_token_key: -1,
-                pos_embed_key: -1
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True,
-                cls_token_key: True,
-                pos_embed_key: True
-            },
+            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
         )
 
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        cls_token_key = prefix + 'cls_token'
-        pos_embed_key = prefix + 'pos_embed'
-        local_state = OrderedDict({
-            weight_key: self.weight,
-            bias_key: self.bias,
-            cls_token_key: self.cls_token,
-            pos_embed_key: self.pos_embed
-        })
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
+        cls_token_key = prefix + "cls_token"
+        pos_embed_key = prefix + "pos_embed"
+        local_state = OrderedDict(
+            {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
+        )
 
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0,
-                cls_token_key: -1,
-                pos_embed_key: -1
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True,
-                cls_token_key: True,
-                pos_embed_key: True
-            },
+            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
             keep_vars=keep_vars,
         )
         # gather in row groups
@@ -552,18 +534,8 @@ class PatchEmbedding2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: 0,
-                bias_key: 0,
-                cls_token_key: -1,
-                pos_embed_key: -1
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True,
-                cls_token_key: True,
-                pos_embed_key: True
-            },
+            dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+            partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
             keep_vars=keep_vars,
         )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -573,15 +545,16 @@ class PatchEmbedding2D(ParallelLayer):
         input_ = split_batch_2d(input_)
 
         B, C, H, W = input_.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        assert (
+            H == self.img_size[0] and W == self.img_size[1]
+        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
 
         weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL)
         bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL)
 
         output = F.conv2d(input_, weight, bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
 
         cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
         pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)
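After the column-group all-gathers reassemble the full kernel and bias, patch extraction is an ordinary convolution whose stride equals the patch size, and the BCHW -> BNC flatten yields one token per patch. A shape walkthrough with hypothetical sizes:

import torch
import torch.nn.functional as F

B, C, H, W, p, embed = 2, 3, 32, 32, 16, 64
weight = torch.randn(embed, C, p, p)  # gathered kernel: one p x p patch filter per output channel
bias = torch.randn(embed)
out = F.conv2d(torch.randn(B, C, H, W), weight, bias, stride=p)
print(out.shape)                             # torch.Size([2, 64, 2, 2])
print(out.flatten(2).transpose(1, 2).shape)  # torch.Size([2, 4, 64]): N = (H/p) * (W/p) patches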
@@ -623,14 +596,16 @@ class Embedding2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
     """
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.normal_(),
-                 *args,
-                 **kwargs):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.normal_(),
+        *args,
+        **kwargs,
+    ):
         super().__init__()
 
         assert_summa_initialization()
@@ -644,7 +619,8 @@ class Embedding2D(ParallelLayer):
         self.embed_kwargs = kwargs
 
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
+            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+        )
 
         self.reset_parameters(weight_initializer)
         self._set_tensor_parallel_attributes()
@@ -665,7 +641,7 @@ class Embedding2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -691,7 +667,7 @@ class Embedding2D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         local_state = OrderedDict({weight_key: self.weight})
 
         # gather in column groups
@@ -754,14 +730,16 @@ class VocabParallelEmbedding2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.normal_(),
-                 *args,
-                 **kwargs):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.normal_(),
+        *args,
+        **kwargs,
+    ):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embed_dim = embedding_dim
@@ -778,9 +756,12 @@ class VocabParallelEmbedding2D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                        device=get_current_device(),
-                        dtype=dtype))
+            torch.empty(
+                (self.num_embeddings_per_partition, self.embed_dim_per_partition),
+                device=get_current_device(),
+                dtype=dtype,
+            )
+        )
 
         self.reset_parameters(weight_initializer)
         self._set_tensor_parallel_attributes()
@@ -796,14 +777,17 @@ class VocabParallelEmbedding2D(ParallelLayer):
         self._fill_padding_idx_with_zero()
 
     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None and \
-                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+        if (
+            self.padding_idx is not None
+            and self.padding_idx >= self.vocab_start_index
+            and self.padding_idx < self.vocab_end_index
+        ):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
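_fill_padding_idx_with_zero works because the vocabulary is split into contiguous [vocab_start_index, vocab_end_index) ranges: only the rank whose range contains padding_idx owns that embedding row, and it zeroes it at the local offset. A sketch of the ownership test (hypothetical sizes):

num_embeddings, padding_idx, num_ranks = 1000, 0, 4
rows_per_rank = num_embeddings // num_ranks
for rank in range(num_ranks):
    start, end = rank * rows_per_rank, (rank + 1) * rows_per_rank
    if start <= padding_idx < end:
        print(f"rank {rank} zeroes its local row {padding_idx - start}")  # only rank 0 fires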
@@ -829,7 +813,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         local_state = OrderedDict({weight_key: self.weight})
 
         # gather in column groups
@@ -857,10 +841,11 @@ class VocabParallelEmbedding2D(ParallelLayer):
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
 
-        output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
-                                      **self.embed_kwargs)
+        output_parallel = F.embedding(
+            masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
+        )
 
-        output_parallel[input_mask, :] = 0.
+        output_parallel[input_mask, :] = 0.0
         output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL)
         return output
 
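The masking in this forward keeps the lookup local: token ids outside this rank's vocabulary range are shifted, clamped to row 0 so F.embedding stays in bounds, and their outputs are zeroed, so after the reduce-scatter exactly one rank has contributed each token's embedding. A single-rank sketch of the trick:

import torch
import torch.nn.functional as F

vocab_start, vocab_end = 4, 8                      # this rank's slice of the vocabulary
weight = torch.randn(vocab_end - vocab_start, 16)  # local embedding table
input_ = torch.tensor([[1, 5, 7, 9]])
mask = (input_ < vocab_start) | (input_ >= vocab_end)
masked = input_.clone() - vocab_start
masked[mask] = 0                                   # any in-bounds index works; the rows are zeroed next
out = F.embedding(masked, weight)
out[mask, :] = 0.0                                 # tokens owned by other ranks contribute nothing
print(out.shape)  # torch.Size([1, 4, 16]); summing these partials across ranks restores the lookup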
@@ -884,14 +869,16 @@ class Classifier2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: Parameter = None,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        num_classes: int,
+        weight: Parameter = None,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
         self.in_features = in_features
         self.num_classes = num_classes
||||
@@ -908,7 +895,8 @@ class Classifier2D(ParallelLayer):
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
@@ -938,8 +926,8 @@ class Classifier2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             if self.has_weight:
@@ -957,34 +945,22 @@ class Classifier2D(ParallelLayer):
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
         )
         # partition in column groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
         )
 
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict()
         if self.has_weight:
             local_state[weight_key] = self.weight
@@ -995,14 +971,8 @@ class Classifier2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
             keep_vars=keep_vars,
         )
         # gather in row groups
@@ -1010,14 +980,8 @@ class Classifier2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
             keep_vars=keep_vars,
         )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -1026,9 +990,21 @@ class Classifier2D(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:
         out_shape = input_.shape[:-1] + (self.num_classes,)
 
-        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
-                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
-                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
+        return classifier_2d(
+            input_,
+            self.weight,
+            self.bias,
+            self.summa_dim,
+            out_shape,
+            self.row_rank,
+            self.col_rank,
+            ParallelMode.PARALLEL_2D_ROW,
+            ParallelMode.PARALLEL_2D_COL,
+            self.data_parallel_rank,
+            self.pipeline_parallel_rank,
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+        )
 
 
 @LAYERS.register_module
@@ -1050,14 +1026,16 @@ class VocabParallelClassifier2D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: Parameter = None,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        num_classes: int,
+        weight: Parameter = None,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
 
         self.in_features = in_features
@@ -1074,13 +1052,14 @@ class VocabParallelClassifier2D(ParallelLayer):
         self.output_size_per_partition = divide(num_classes, self.summa_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs))
+                torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)
+            )
             self.has_weight = True
         # create bias, shape: [h/q]
         if bias:
@@ -1109,8 +1088,8 @@ class VocabParallelClassifier2D(ParallelLayer):
 
     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             if self.has_weight:
@@ -1128,34 +1107,22 @@ class VocabParallelClassifier2D(ParallelLayer):
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
         # partition in column groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
         )
 
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
 
     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict()
         if self.has_weight:
             local_state[weight_key] = self.weight
@@ -1166,14 +1133,8 @@ class VocabParallelClassifier2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_COL,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         # gather in row groups
@@ -1181,14 +1142,8 @@ class VocabParallelClassifier2D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             ParallelMode.PARALLEL_2D_ROW,
-            dims={
-                weight_key: -1,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: True
-            },
+            dims={weight_key: -1, bias_key: 0},
+            partition_states={weight_key: True, bias_key: True},
             keep_vars=keep_vars,
         )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -1200,14 +1155,34 @@ class VocabParallelClassifier2D(ParallelLayer):
         # output: [m/q, n/q, h/q]
         out_shape = x.shape[:-1] + (self.output_size_per_partition,)
 
-        output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
-                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
-                                     self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                     self.tensor_parallel_size)
+        output = Matmul_ABT_2D.apply(
+            x,
+            self.weight,
+            self.summa_dim,
+            out_shape,
+            self.row_rank,
+            self.col_rank,
+            ParallelMode.PARALLEL_2D_ROW,
+            ParallelMode.PARALLEL_2D_COL,
+            self.data_parallel_rank,
+            self.pipeline_parallel_rank,
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+        )
 
         if self.bias is not None:
-            output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank,
-                                 ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
-                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                 self.tensor_parallel_size)
+            output = add_bias_2d(
+                output,
+                self.bias,
+                self.output_size_per_partition,
+                self.row_rank,
+                self.col_rank,
+                ParallelMode.PARALLEL_2D_ROW,
+                ParallelMode.PARALLEL_2D_COL,
+                False,
+                self.data_parallel_rank,
+                self.pipeline_parallel_rank,
+                self.pipeline_parallel_size,
+                self.tensor_parallel_size,
+            )
         return output
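Matmul_ABT_2D is the C = A @ Bᵀ variant of the same SUMMA schedule, which is why VocabParallelClassifier2D can keep its weight in the usual Linear layout, [num_classes/q, in_features/q], without ever materializing a transpose. A single-rank shape check (hypothetical tile sizes):

import torch

x = torch.randn(2, 10, 512)  # [..., k]: local activation tile
w = torch.randn(128, 512)    # [n, k]: output rows first, as nn.Linear stores it
print((x @ w.t()).shape)     # torch.Size([2, 10, 128]) matches out_shape per rank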