Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
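The commit message describes the change only at a high level, and this page does not show the updated pre-commit configuration itself. As a rough, hypothetical sketch of what a setup matching these bullet points could look like (the hook choices, versions, and the CUDA exclude pattern below are assumptions, not the committed file):

# Illustrative only: an assumed .pre-commit-config.yaml fragment, not taken from this commit.
repos:
  - repo: https://github.com/psf/black          # Python formatter; the quote and line-wrapping changes below are consistent with black
    rev: 23.9.1                                  # assumed version
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6                                 # assumed version
    hooks:
      - id: clang-format
        exclude: .*\.(cu|cuh)$                   # assumed pattern for "ignore cuda for clang-format"

With such a config in place, running `pre-commit run --all-files` once reformats the whole repository, which is what produces the formatting-only diff below.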
@@ -10,6 +10,14 @@ from .layers import (
)

__all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D'
"reduce_by_batch_3d",
"split_tensor_3d",
"split_batch_3d",
"Linear3D",
"LayerNorm3D",
"PatchEmbedding3D",
"Classifier3D",
"Embedding3D",
"VocabParallelEmbedding3D",
"VocabParallelClassifier3D",
]
@@ -16,7 +16,6 @@ from ._utils import get_parallel_mode_from_env, push_async_grad


class _Linear3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -52,7 +51,8 @@ class _Linear3D(torch.autograd.Function):
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])
)
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

@@ -92,7 +92,6 @@ def linear_3d(


class _Classifier3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -131,7 +130,8 @@ class _Classifier3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])
)
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
@@ -187,7 +187,6 @@ def classifier_3d(


class _VocabParallelClassifier3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -230,7 +229,8 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])
)
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

@@ -296,7 +296,7 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = dz * mu * (-0.5) * sigma ** (-3)
dmean = -dmu
dvar = torch.sum(dvar, -1, keepdim=True)
dmean = torch.sum(dmean, -1, keepdim=True)
@@ -305,7 +305,6 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):


class _Layernorm3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(
@@ -415,20 +414,24 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
assert dim_size % world_size == 0, (
f"The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), "
f"cannot split tensor evenly"
)
if tensor.size(dim) <= 1:
return tensor
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous()
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), dim=dim)[
gpc.get_local_rank(parallel_mode)
].contiguous()
return output


def split_batch_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
def split_batch_3d(
input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT,
) -> Tensor:
r"""Splits 3D tensor in batch.

Args:
@@ -456,7 +459,6 @@ def split_batch_3d(input_: Tensor,


class _ReduceTensor3D(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@@ -481,7 +483,6 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:


class _AllGatherTensor3D(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -511,7 +512,6 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode)


class _ReduceScatterTensor3D(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -538,21 +538,23 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo
"""
dim_size = tensor.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).'
assert (
dim_size % world_size == 0
), f"The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size})."

return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)


class _ReduceByBatch3D(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
def forward(
ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False,
) -> Tensor:
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean
@@ -571,10 +573,9 @@ class _ReduceByBatch3D(torch.autograd.Function):
return output_grad, None, None, None


def reduce_by_batch_3d(tensor: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
def reduce_by_batch_3d(
tensor: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False
) -> Tensor:
r"""All-reduce the input from the model parallel region.

Args:
@@ -18,17 +18,24 @@ from colossalai.legacy.global_variables import tensor_parallel_env as env
def get_depth_from_env() -> int:
try:
depth = env.depth_3d
assert depth > 0, 'DEPTH must be greater than zero'
assert depth > 0, "DEPTH must be greater than zero"
return depth

except KeyError as e:
raise EnvironmentError('DEPTH is not found in the current environment, '
'please make sure that you have used the correct process group initializer')
except KeyError:
raise EnvironmentError(
"DEPTH is not found in the current environment, "
"please make sure that you have used the correct process group initializer"
)


def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
assert group in [
INPUT_GROUP_3D,
WEIGHT_GROUP_3D,
OUTPUT_GROUP_3D,
INPUT_X_WEIGHT_3D,
OUTPUT_X_WEIGHT_3D,
], f"{group} is not valid for 3D tensor parallelism."
return getattr(env, group)


@@ -44,12 +51,10 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
assert tensor.shape == shape, "{} does not match {}".format(tensor.shape, shape)


class AsyncGradientBucket(object):

def __init__(self):
self.bucket = OrderedDict()
@@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer):
"""

def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):

super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -70,10 +69,12 @@ class LayerNorm3D(ParallelLayer):
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)

self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
if bias:
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None
self.variance_epsilon = eps
@@ -94,8 +95,8 @@ class LayerNorm3D(ParallelLayer):

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -107,15 +108,11 @@ class LayerNorm3D(ParallelLayer):
local_state[bias_key] = bias

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
dims={weight_key: 0, bias_key: 0},
partition_states={
weight_key: True,
bias_key: True,
@@ -130,26 +127,19 @@ class LayerNorm3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias

# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -185,14 +175,16 @@ class Linear3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -207,13 +199,17 @@ class Linear3D(ParallelLayer):
self.bias_features_per_partition = divide(out_features, self.depth)

self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None

@@ -239,15 +235,17 @@ class Linear3D(ParallelLayer):

if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
self.bias.register_hook(self._sync_grad_hook)

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -260,53 +258,34 @@ class Linear3D(ParallelLayer):
local_state[bias_key] = bias

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)

super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -315,14 +294,8 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -330,30 +303,17 @@ class Linear3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -418,7 +380,8 @@ class Classifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
@@ -449,8 +412,8 @@ class Classifier3D(ParallelLayer):

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -464,19 +427,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = bias

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -487,8 +443,8 @@ class Classifier3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
@@ -496,19 +452,12 @@ class Classifier3D(ParallelLayer):
local_state[bias_key] = self.bias

# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -569,14 +520,18 @@ class VocabParallelClassifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype))
torch.empty(
self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype))
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
)
else:
self.bias = None

@@ -602,15 +557,17 @@ class VocabParallelClassifier3D(ParallelLayer):

if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
broadcast(
self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode,
)
register_async_grad_hook(self.bias)

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -624,53 +581,34 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state[bias_key] = bias

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in weight groups
local_state = partition_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)

super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -679,14 +617,8 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.weight_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in input groups
@@ -694,30 +626,17 @@ class VocabParallelClassifier3D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
self.input_parallel_mode,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -783,15 +704,18 @@ class PatchEmbedding3D(ParallelLayer):
self.flatten = flatten

self.weight = nn.Parameter(
torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
torch.empty(
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
)
)
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))

self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
)

self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
@@ -826,10 +750,10 @@ class PatchEmbedding3D(ParallelLayer):

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -849,23 +773,12 @@ class PatchEmbedding3D(ParallelLayer):
local_state[pos_embed_key] = pos_embed

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
)
# broadcast in input groups
if gpc.get_local_rank(self.weight_parallel_mode) == 0:
@@ -876,47 +789,33 @@ class PatchEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
local_state = OrderedDict(
{weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
)

# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)

def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""

def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -979,7 +880,8 @@ class Embedding3D(ParallelLayer):
self.embed_kwargs = kwargs

self.weight = nn.Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
)

self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -996,8 +898,9 @@ class Embedding3D(ParallelLayer):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
broadcast(
self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode
)
self.weight.register_hook(self._sync_grad_hook)

def _fill_padding_idx_with_zero(self) -> None:
@@ -1007,7 +910,7 @@ class Embedding3D(ParallelLayer):

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1015,8 +918,7 @@ class Embedding3D(ParallelLayer):
local_state[weight_key] = weight

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1032,12 +934,11 @@ class Embedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})

# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1049,9 +950,9 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)

def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
input_ = split_batch_3d(
input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode
)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

return output
@@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""

def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@@ -1114,9 +1017,12 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth

self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype,
)
)

self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -1132,14 +1038,17 @@ class VocabParallelEmbedding3D(ParallelLayer):
self._fill_padding_idx_with_zero()

def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
if (
self.padding_idx is not None
and self.padding_idx >= self.vocab_start_index
and self.padding_idx < self.vocab_end_index
):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)

def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -1147,8 +1056,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
local_state[weight_key] = weight

# partition in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1174,7 +1082,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)

def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})

# gather in weight groups
@@ -1195,8 +1103,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
keep_vars=keep_vars,
)
# gather in output groups
if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
gpc.get_local_rank(self.weight_parallel_mode) == 0:
if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
self.output_parallel_mode,
@@ -1218,7 +1125,7 @@ class VocabParallelEmbedding3D(ParallelLayer):

output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

output_parallel[input_mask, :] = 0.
output_parallel[input_mask, :] = 0.0
output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)

return output