[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer):
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -70,10 +69,12 @@ class LayerNorm3D(ParallelLayer):
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
if bias:
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
self.variance_epsilon = eps
@@ -94,8 +95,8 @@ class LayerNorm3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -107,15 +108,11 @@ class LayerNorm3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
dims={weight_key: 0, bias_key: 0},
partition_states={
weight_key: True,
bias_key: True,
@@ -130,26 +127,19 @@ class LayerNorm3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -185,14 +175,16 @@ class Linear3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -207,13 +199,17 @@ class Linear3D(ParallelLayer):
self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
@@ -239,15 +235,17 @@ class Linear3D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -260,53 +258,34 @@ class Linear3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -315,14 +294,8 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -330,30 +303,17 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -418,7 +380,8 @@ class Classifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
@@ -449,8 +412,8 @@ class Classifier3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -464,19 +427,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -487,8 +443,8 @@ class Classifier3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
@@ -496,19 +452,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -569,14 +520,18 @@ class VocabParallelClassifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
@@ -602,15 +557,17 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -624,53 +581,34 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -679,14 +617,8 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -694,30 +626,17 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -783,15 +704,18 @@ class PatchEmbedding3D(ParallelLayer):
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
torch.empty(
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
)
)
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
@@ -826,10 +750,10 @@ class PatchEmbedding3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -849,23 +773,12 @@ class PatchEmbedding3D(ParallelLayer):
local_state[pos_embed_key] = pos_embed
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -876,47 +789,33 @@ class PatchEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
local_state = OrderedDict(
{weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -979,7 +880,8 @@ class Embedding3D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = nn.Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -996,8 +898,9 @@ class Embedding3D(ParallelLayer):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
broadcast(
self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode
)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None:
@@ -1007,7 +910,7 @@ class Embedding3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1015,8 +918,7 @@ class Embedding3D(ParallelLayer):
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1032,12 +934,11 @@ class Embedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1049,9 +950,9 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
@@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@@ -1114,9 +1017,12 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype,
)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -1132,14 +1038,17 @@ class VocabParallelEmbedding3D(ParallelLayer):
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
if (
self.padding_idx is not None
and self.padding_idx >= self.vocab_start_index
and self.padding_idx < self.vocab_end_index
):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1147,8 +1056,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1174,7 +1082,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in weight groups
@@ -1195,8 +1103,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1218,7 +1125,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output_parallel[input_mask, :] = 0.0
output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)
return output