[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -10,6 +10,14 @@ from .layers import (
)
__all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D'
"reduce_by_batch_3d",
"split_tensor_3d",
"split_batch_3d",
"Linear3D",
"LayerNorm3D",
"PatchEmbedding3D",
"Classifier3D",
"Embedding3D",
"VocabParallelEmbedding3D",
"VocabParallelClassifier3D",
]

View File

@@ -16,7 +16,6 @@ from ._utils import get_parallel_mode_from_env, push_async_grad
class _Linear3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -52,7 +51,8 @@ class _Linear3D(torch.autograd.Function):
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])
)
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
@@ -92,7 +92,6 @@ def linear_3d(
class _Classifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -131,7 +130,8 @@ class _Classifier3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])
)
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
@@ -187,7 +187,6 @@ def classifier_3d(
class _VocabParallelClassifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -230,7 +229,8 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])
)
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
@@ -296,7 +296,7 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = dz * mu * (-0.5) * sigma ** (-3)
dmean = -dmu
dvar = torch.sum(dvar, -1, keepdim=True)
dmean = torch.sum(dmean, -1, keepdim=True)
@@ -305,7 +305,6 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
class _Layernorm3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(
@@ -415,20 +414,24 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
assert dim_size % world_size == 0, (
f"The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), "
f"cannot split tensor evenly"
)
if tensor.size(dim) <= 1:
return tensor
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous()
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), dim=dim)[
gpc.get_local_rank(parallel_mode)
].contiguous()
return output
def split_batch_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
def split_batch_3d(
input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT,
) -> Tensor:
r"""Splits 3D tensor in batch.
Args:
@@ -456,7 +459,6 @@ def split_batch_3d(input_: Tensor,
class _ReduceTensor3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@@ -481,7 +483,6 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
class _AllGatherTensor3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -511,7 +512,6 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)
class _ReduceScatterTensor3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -538,21 +538,23 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).'
assert (
dim_size % world_size == 0
), f"The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size})."
return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
class _ReduceByBatch3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
def forward(
ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False,
) -> Tensor:
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean
@@ -571,10 +573,9 @@ class _ReduceByBatch3D(torch.autograd.Function):
return output_grad, None, None, None
def reduce_by_batch_3d(tensor: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
def reduce_by_batch_3d(
tensor: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False
) -> Tensor:
r"""All-reduce the input from the model parallel region.
Args:

View File

@@ -18,17 +18,24 @@ from colossalai.legacy.global_variables import tensor_parallel_env as env
def get_depth_from_env() -> int:
try:
depth = env.depth_3d
assert depth > 0, 'DEPTH must be greater than zero'
assert depth > 0, "DEPTH must be greater than zero"
return depth
except KeyError as e:
raise EnvironmentError('DEPTH is not found in the current environment, '
'please make sure that you have used the correct process group initializer')
except KeyError:
raise EnvironmentError(
"DEPTH is not found in the current environment, "
"please make sure that you have used the correct process group initializer"
)
def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
assert group in [
INPUT_GROUP_3D,
WEIGHT_GROUP_3D,
OUTPUT_GROUP_3D,
INPUT_X_WEIGHT_3D,
OUTPUT_X_WEIGHT_3D,
], f"{group} is not valid for 3D tensor parallelism."
return getattr(env, group)
@@ -44,12 +51,10 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
assert tensor.shape == shape, "{} does not match {}".format(tensor.shape, shape)
class AsyncGradientBucket(object):
def __init__(self):
self.bucket = OrderedDict()

View File

@@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer):
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -70,10 +69,12 @@ class LayerNorm3D(ParallelLayer):
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
if bias:
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
self.variance_epsilon = eps
@@ -94,8 +95,8 @@ class LayerNorm3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -107,15 +108,11 @@ class LayerNorm3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
dims={weight_key: 0, bias_key: 0},
partition_states={
weight_key: True,
bias_key: True,
@@ -130,26 +127,19 @@ class LayerNorm3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -185,14 +175,16 @@ class Linear3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -207,13 +199,17 @@ class Linear3D(ParallelLayer):
self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
@@ -239,15 +235,17 @@ class Linear3D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -260,53 +258,34 @@ class Linear3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -315,14 +294,8 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -330,30 +303,17 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -418,7 +380,8 @@ class Classifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
@@ -449,8 +412,8 @@ class Classifier3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -464,19 +427,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -487,8 +443,8 @@ class Classifier3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
@@ -496,19 +452,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = self.bias
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -569,14 +520,18 @@ class VocabParallelClassifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
@@ -602,15 +557,17 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -624,53 +581,34 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state[bias_key] = bias
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -679,14 +617,8 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -694,30 +626,17 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -783,15 +704,18 @@ class PatchEmbedding3D(ParallelLayer):
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
torch.empty(
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
)
)
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
@@ -826,10 +750,10 @@ class PatchEmbedding3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -849,23 +773,12 @@ class PatchEmbedding3D(ParallelLayer):
local_state[pos_embed_key] = pos_embed
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -876,47 +789,33 @@ class PatchEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
local_state = OrderedDict(
{weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -979,7 +880,8 @@ class Embedding3D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = nn.Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -996,8 +898,9 @@ class Embedding3D(ParallelLayer):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
broadcast(
self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode
)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None:
@@ -1007,7 +910,7 @@ class Embedding3D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1015,8 +918,7 @@ class Embedding3D(ParallelLayer):
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1032,12 +934,11 @@ class Embedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1049,9 +950,9 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
@@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@@ -1114,9 +1017,12 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype,
)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -1132,14 +1038,17 @@ class VocabParallelEmbedding3D(ParallelLayer):
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
if (
self.padding_idx is not None
and self.padding_idx >= self.vocab_start_index
and self.padding_idx < self.vocab_end_index
):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1147,8 +1056,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
local_state[weight_key] = weight
# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1174,7 +1082,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in weight groups
@@ -1195,8 +1103,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1218,7 +1125,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output_parallel[input_mask, :] = 0.0
output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)
return output