[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
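The pre-commit configuration itself is not part of this diff. As a rough sketch only, a .pre-commit-config.yaml that would produce the changes below might look like the following; the hook revisions and the exclude pattern are assumptions, not taken from the repository:

    repos:
      - repo: https://github.com/psf/black
        rev: 23.9.1  # assumed revision, not from the repo
        hooks:
          - id: black
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v16.0.6  # assumed revision, not from the repo
        hooks:
          - id: clang-format
            exclude: '\.cu$'  # "ignore cuda for clang-format"; exact pattern assumed

With a config like this, running "pre-commit run --all-files" reformats the whole tree in one pass, which is what produced the large but purely mechanical diff below.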
@@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer):
     """

     def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
         super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -70,10 +69,12 @@ class LayerNorm3D(ParallelLayer):
         self.normalized_shape_per_partition = divide(normalized_shape, self.depth)

         self.weight = Parameter(
-            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
+            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+        )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
+                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
         self.variance_epsilon = eps
@@ -94,8 +95,8 @@ class LayerNorm3D(ParallelLayer):

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -107,15 +108,11 @@ class LayerNorm3D(ParallelLayer):
                 local_state[bias_key] = bias

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0
-                },
+                dims={weight_key: 0, bias_key: 0},
                 partition_states={
                     weight_key: True,
                     bias_key: True,
@@ -130,26 +127,19 @@ class LayerNorm3D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict({weight_key: self.weight})
         if self.bias is not None:
             local_state[bias_key] = self.bias

         # gather in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True
-                },
+                dims={weight_key: 0, bias_key: 0},
+                partition_states={weight_key: True, bias_key: True},
                 keep_vars=keep_vars,
             )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -185,14 +175,16 @@ class Linear3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """

-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 skip_bias_add: bool = False,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        skip_bias_add: bool = False,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -207,13 +199,17 @@ class Linear3D(ParallelLayer):
         self.bias_features_per_partition = divide(out_features, self.depth)

         self.weight = Parameter(
-            torch.empty(self.in_features_per_partition,
-                        self.out_features_per_partition,
-                        device=get_current_device(),
-                        dtype=dtype))
+            torch.empty(
+                self.in_features_per_partition,
+                self.out_features_per_partition,
+                device=get_current_device(),
+                dtype=dtype,
+            )
+        )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None

@@ -239,15 +235,17 @@ class Linear3D(ParallelLayer):

         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
-            broadcast(self.bias,
-                      gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
-                      self.output_x_weight_parallel_mode)
+            broadcast(
+                self.bias,
+                gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
+                self.output_x_weight_parallel_mode,
+            )
             self.bias.register_hook(self._sync_grad_hook)

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -260,53 +258,34 @@ class Linear3D(ParallelLayer):
                 local_state[bias_key] = bias

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: False
-                },
+                dims={weight_key: 0, bias_key: 0},
+                partition_states={weight_key: True, bias_key: False},
             )
         # partition in input groups
         if gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.input_parallel_mode,
-                dims={
-                    weight_key: -1,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True
-                },
+                dims={weight_key: -1, bias_key: 0},
+                partition_states={weight_key: True, bias_key: True},
             )
         # partition in weight groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             self.weight_parallel_mode,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
         )

         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict({weight_key: self.weight})
         if self.bias is not None:
             local_state[bias_key] = self.bias
@@ -315,14 +294,8 @@ class Linear3D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             self.weight_parallel_mode,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
             keep_vars=keep_vars,
         )
         # gather in input groups
@@ -330,30 +303,17 @@ class Linear3D(ParallelLayer):
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.input_parallel_mode,
-                dims={
-                    weight_key: -1,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True
-                },
+                dims={weight_key: -1, bias_key: 0},
+                partition_states={weight_key: True, bias_key: True},
                 keep_vars=keep_vars,
             )
             # gather in output groups
-            if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                    gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
                 local_state = gather_tensor_parallel_state_dict(
                     local_state,
                     self.output_parallel_mode,
-                    dims={
-                        weight_key: 0,
-                        bias_key: 0
-                    },
-                    partition_states={
-                        weight_key: True,
-                        bias_key: False
-                    },
+                    dims={weight_key: 0, bias_key: 0},
+                    partition_states={weight_key: True, bias_key: False},
                     keep_vars=keep_vars,
                 )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """

-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: Parameter = None,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        num_classes: int,
+        weight: Parameter = None,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
         self.in_features = in_features
         self.num_classes = num_classes
@@ -418,7 +380,8 @@ class Classifier3D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
+                torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
+            )
             self.has_weight = True
         if bias:
             self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
@@ -449,8 +412,8 @@ class Classifier3D(ParallelLayer):

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             if self.has_weight:
@@ -464,19 +427,12 @@ class Classifier3D(ParallelLayer):
                     local_state[bias_key] = bias

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: -1,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: False
-                },
+                dims={weight_key: -1, bias_key: 0},
+                partition_states={weight_key: True, bias_key: False},
             )
         # broadcast in input groups
         if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -487,8 +443,8 @@ class Classifier3D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict()
         if self.has_weight:
             local_state[weight_key] = self.weight
@@ -496,19 +452,12 @@ class Classifier3D(ParallelLayer):
             local_state[bias_key] = self.bias

         # gather in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: -1,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: False
-                },
+                dims={weight_key: -1, bias_key: 0},
+                partition_states={weight_key: True, bias_key: False},
                 keep_vars=keep_vars,
             )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """

-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: Parameter = None,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+    def __init__(
+        self,
+        in_features: int,
+        num_classes: int,
+        weight: Parameter = None,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
         super().__init__()
         self.in_features = in_features
         self.num_classes = num_classes
@@ -569,14 +520,18 @@ class VocabParallelClassifier3D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.out_features_per_partition,
-                            self.in_features_per_partition,
-                            device=get_current_device(),
-                            dtype=dtype))
+                torch.empty(
+                    self.out_features_per_partition,
+                    self.in_features_per_partition,
+                    device=get_current_device(),
+                    dtype=dtype,
+                )
+            )
             self.has_weight = True
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
+                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None

@@ -602,15 +557,17 @@ class VocabParallelClassifier3D(ParallelLayer):

         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
-            broadcast(self.bias,
-                      gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
-                      self.output_x_weight_parallel_mode)
+            broadcast(
+                self.bias,
+                gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
+                self.output_x_weight_parallel_mode,
+            )
             register_async_grad_hook(self.bias)

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             if self.has_weight:
@@ -624,53 +581,34 @@ class VocabParallelClassifier3D(ParallelLayer):
                     local_state[bias_key] = bias

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: -1,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: False
-                },
+                dims={weight_key: -1, bias_key: 0},
+                partition_states={weight_key: True, bias_key: False},
             )
         # partition in input groups
         if gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.input_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True
-                },
+                dims={weight_key: 0, bias_key: 0},
+                partition_states={weight_key: True, bias_key: True},
             )
         # partition in weight groups
         local_state = partition_tensor_parallel_state_dict(
             local_state,
             self.weight_parallel_mode,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
         )

         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
         local_state = OrderedDict({weight_key: self.weight})
         if self.bias is not None:
             local_state[bias_key] = self.bias
@@ -679,14 +617,8 @@ class VocabParallelClassifier3D(ParallelLayer):
         local_state = gather_tensor_parallel_state_dict(
             local_state,
             self.weight_parallel_mode,
-            dims={
-                weight_key: 0,
-                bias_key: 0
-            },
-            partition_states={
-                weight_key: True,
-                bias_key: False
-            },
+            dims={weight_key: 0, bias_key: 0},
+            partition_states={weight_key: True, bias_key: False},
             keep_vars=keep_vars,
         )
         # gather in input groups
@@ -694,30 +626,17 @@ class VocabParallelClassifier3D(ParallelLayer):
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.input_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True
-                },
+                dims={weight_key: 0, bias_key: 0},
+                partition_states={weight_key: True, bias_key: True},
                 keep_vars=keep_vars,
             )
             # gather in output groups
-            if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                    gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
                 local_state = gather_tensor_parallel_state_dict(
                     local_state,
                     self.output_parallel_mode,
-                    dims={
-                        weight_key: -1,
-                        bias_key: 0
-                    },
-                    partition_states={
-                        weight_key: True,
-                        bias_key: False
-                    },
+                    dims={weight_key: -1, bias_key: 0},
+                    partition_states={weight_key: True, bias_key: False},
                     keep_vars=keep_vars,
                 )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """

-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 embed_size: int,
-                 flatten: bool = True,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 position_embed_initializer: Callable = init.zeros_()):
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_size: int,
+        flatten: bool = True,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        position_embed_initializer: Callable = init.zeros_(),
+    ):
         super().__init__()
         self.depth = get_depth_from_env()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -783,15 +704,18 @@ class PatchEmbedding3D(ParallelLayer):
         self.flatten = flatten

         self.weight = nn.Parameter(
-            torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
-                        device=get_current_device(),
-                        dtype=dtype))
+            torch.empty(
+                (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
+            )
+        )
         self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))

         self.cls_token = nn.Parameter(
-            torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
+            torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+        )
         self.pos_embed = nn.Parameter(
-            torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
+            torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+        )

         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
         self._set_tensor_parallel_attributes()
@@ -826,10 +750,10 @@ class PatchEmbedding3D(ParallelLayer):

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        cls_token_key = prefix + 'cls_token'
-        pos_embed_key = prefix + 'pos_embed'
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
+        cls_token_key = prefix + "cls_token"
+        pos_embed_key = prefix + "pos_embed"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -849,23 +773,12 @@ class PatchEmbedding3D(ParallelLayer):
                     local_state[pos_embed_key] = pos_embed

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0,
-                    cls_token_key: -1,
-                    pos_embed_key: -1
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True,
-                    cls_token_key: True,
-                    pos_embed_key: True
-                },
+                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
             )
         # broadcast in input groups
         if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -876,47 +789,33 @@ class PatchEmbedding3D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        cls_token_key = prefix + 'cls_token'
-        pos_embed_key = prefix + 'pos_embed'
-        local_state = OrderedDict({
-            weight_key: self.weight,
-            bias_key: self.bias,
-            cls_token_key: self.cls_token,
-            pos_embed_key: self.pos_embed
-        })
+        weight_key = prefix + "weight"
+        bias_key = prefix + "bias"
+        cls_token_key = prefix + "cls_token"
+        pos_embed_key = prefix + "pos_embed"
+        local_state = OrderedDict(
+            {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
+        )

         # gather in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
-                dims={
-                    weight_key: 0,
-                    bias_key: 0,
-                    cls_token_key: -1,
-                    pos_embed_key: -1
-                },
-                partition_states={
-                    weight_key: True,
-                    bias_key: True,
-                    cls_token_key: True,
-                    pos_embed_key: True
-                },
+                dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
+                partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
                 keep_vars=keep_vars,
             )
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             destination.update(local_state)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_,
-                                input_parallel_mode=self.input_parallel_mode,
-                                weight_parallel_mode=self.weight_parallel_mode)
+        input_ = split_batch_3d(
+            input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
+        )
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
@@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
     """

-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.normal_(),
-                 *args,
-                 **kwargs):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.normal_(),
+        *args,
+        **kwargs,
+    ):
         super().__init__()
         self.depth = get_depth_from_env()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -979,7 +880,8 @@ class Embedding3D(ParallelLayer):
         self.embed_kwargs = kwargs

         self.weight = nn.Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
+            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+        )

         self.reset_parameters(weight_initializer)
         self._set_tensor_parallel_attributes()
@@ -996,8 +898,9 @@ class Embedding3D(ParallelLayer):
         fan_in, fan_out = self.num_embeddings, self.embed_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
-        broadcast(self.weight,
-                  gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
+        broadcast(
+            self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode
+        )
        self.weight.register_hook(self._sync_grad_hook)

     def _fill_padding_idx_with_zero(self) -> None:
@@ -1007,7 +910,7 @@ class Embedding3D(ParallelLayer):

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -1015,8 +918,7 @@ class Embedding3D(ParallelLayer):
                 local_state[weight_key] = weight

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
@@ -1032,12 +934,11 @@ class Embedding3D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         local_state = OrderedDict({weight_key: self.weight})

         # gather in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
@@ -1049,9 +950,9 @@ class Embedding3D(ParallelLayer):
             destination.update(local_state)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_,
-                                input_parallel_mode=self.input_parallel_mode,
-                                weight_parallel_mode=self.weight_parallel_mode)
+        input_ = split_batch_3d(
+            input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
+        )
         output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

         return output
@@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer):
         `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """

-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 padding_idx: int = None,
-                 dtype: torch.dtype = None,
-                 weight_initializer: Callable = init.normal_(),
-                 *args,
-                 **kwargs):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        weight_initializer: Callable = init.normal_(),
+        *args,
+        **kwargs,
+    ):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embed_dim = embedding_dim
@@ -1114,9 +1017,12 @@ class VocabParallelEmbedding3D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth

         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                        device=get_current_device(),
-                        dtype=dtype))
+            torch.empty(
+                (self.num_embeddings_per_partition, self.embed_dim_per_partition),
+                device=get_current_device(),
+                dtype=dtype,
+            )
+        )

         self.reset_parameters(weight_initializer)
         self._set_tensor_parallel_attributes()
@@ -1132,14 +1038,17 @@ class VocabParallelEmbedding3D(ParallelLayer):
         self._fill_padding_idx_with_zero()

     def _fill_padding_idx_with_zero(self) -> None:
-        if self.padding_idx is not None and \
-                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+        if (
+            self.padding_idx is not None
+            and self.padding_idx >= self.vocab_start_index
+            and self.padding_idx < self.vocab_end_index
+        ):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

     def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
             # weight
             weight = state_dict.pop(weight_key, None)
@@ -1147,8 +1056,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
                 local_state[weight_key] = weight

         # partition in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = partition_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
@@ -1174,7 +1082,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
         super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

     def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
+        weight_key = prefix + "weight"
         local_state = OrderedDict({weight_key: self.weight})

         # gather in weight groups
@@ -1195,8 +1103,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
             keep_vars=keep_vars,
         )
         # gather in output groups
-        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
-                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
             local_state = gather_tensor_parallel_state_dict(
                 local_state,
                 self.output_parallel_mode,
@@ -1218,7 +1125,7 @@ class VocabParallelEmbedding3D(ParallelLayer):

         output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

-        output_parallel[input_mask, :] = 0.
+        output_parallel[input_mask, :] = 0.0
         output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)

         return output