From 0b8161fab800d1571d4d0e00ee4d399c62e66710 Mon Sep 17 00:00:00 2001 From: kurisusnowdeng Date: Wed, 26 Oct 2022 20:54:39 +0800 Subject: [PATCH] updated tp layers --- colossalai/constants.py | 2 + colossalai/context/parallel_mode.py | 2 + .../initializer_3d.py | 112 +++++- colossalai/global_variables.py | 10 +- colossalai/nn/layer/parallel_1d/_operation.py | 51 +++ colossalai/nn/layer/parallel_1d/layers.py | 29 +- colossalai/nn/layer/parallel_3d/_operation.py | 377 +++++++++++------- colossalai/nn/layer/parallel_3d/_utils.py | 89 ++++- colossalai/nn/layer/parallel_3d/layers.py | 169 +++++--- docker/Dockerfile | 6 +- .../test_3d/checks_3d/check_layer_3d.py | 79 ++-- tests/test_layers/test_3d/checks_3d/common.py | 6 +- tests/test_layers/test_3d/test_3d.py | 6 +- 13 files changed, 645 insertions(+), 293 deletions(-) diff --git a/colossalai/constants.py b/colossalai/constants.py index c8aaafdfa..6cf9085f9 100644 --- a/colossalai/constants.py +++ b/colossalai/constants.py @@ -23,6 +23,8 @@ INITIALIZER_MAPPING = { INPUT_GROUP_3D = 'input_group_3d' WEIGHT_GROUP_3D = 'weight_group_3d' OUTPUT_GROUP_3D = 'output_group_3d' +INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' +OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' # Attributes of tensor parallel parameters IS_TENSOR_PARALLEL = 'is_tensor_parallel' diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py index dc50dca05..1cf6fa53d 100644 --- a/colossalai/context/parallel_mode.py +++ b/colossalai/context/parallel_mode.py @@ -39,6 +39,8 @@ class ParallelMode(Enum): PARALLEL_3D_INPUT = '3d_input' PARALLEL_3D_WEIGHT = '3d_weight' PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" + PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" # 2.5D parallel PARALLEL_2P5D_ROW = '2p5d_row' diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py index 0cda7a52d..b752b8f45 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -176,6 +176,112 @@ class Initializer_3D_Output(ProcessGroupInitializer): return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode +class Initializer_3D_InputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. 
+ """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT + env.input_x_weight_group_3d = mode + + for h in range(self.num_group): + for k in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + +class Initializer_3D_OutputxWeight(ProcessGroupInitializer): + """3D tensor parallel initialization among input. + + Args: + num_group (int): The number of all tensor groups. + depth (int): Depth of 3D parallelism. + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + config (Config): Running configuration. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + """ + + def __init__(self, num_group: int, depth: int, *args): + super().__init__(*args) + self.num_group = num_group + self.depth = depth + + def init_dist_group(self): + """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. + + Returns: + Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): + 3D tensor parallelism's information among input in a tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT + env.output_x_weight_group_3d = mode + + for h in range(self.num_group): + for j in range(self.depth): + ranks = [ + h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) + for i in range(self.depth) + ] + group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + + @DIST_GROUP_INITIALIZER.register_module class Initializer_3D(ProcessGroupInitializer): """Serve as the single entry point to 3D parallel initialization. @@ -200,6 +306,8 @@ class Initializer_3D(ProcessGroupInitializer): self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args) self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args) + self.input_x_weight_initializer = Initializer_3D_InputxWeight(self.num_group, self.depth, *args) + self.output_x_weight_initializer = Initializer_3D_OutputxWeight(self.num_group, self.depth, *args) def init_dist_group(self): """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. 
@@ -211,6 +319,8 @@ class Initializer_3D(ProcessGroupInitializer): parallel_setting = [ self.input_initializer.init_dist_group(), self.weight_initializer.init_dist_group(), - self.output_initializer.init_dist_group() + self.output_initializer.init_dist_group(), + self.input_x_weight_initializer.init_dist_group(), + self.output_x_weight_initializer.init_dist_group() ] return parallel_setting diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py index 24f8b60dd..e3575ea12 100644 --- a/colossalai/global_variables.py +++ b/colossalai/global_variables.py @@ -22,7 +22,9 @@ class TensorParallelEnv(object): depth_3d: int = None, input_group_3d=None, weight_group_3d=None, - output_group_3d=None): + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None): self.mode = mode self.vocab_parallel = vocab_parallel self.parallel_input_1d = parallel_input_1d @@ -33,6 +35,8 @@ class TensorParallelEnv(object): self.input_group_3d = input_group_3d self.weight_group_3d = weight_group_3d self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d def save(self): return dict(mode=self.mode, @@ -44,7 +48,9 @@ class TensorParallelEnv(object): depth_3d=self.depth_3d, input_group_3d=self.input_group_3d, weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d) + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d) tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 7944598b7..394334558 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,4 +1,6 @@ import torch +import torch.distributed as dist +from colossalai.core import global_context as gpc try: import fused_mix_prec_layer_norm_cuda @@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): weight_, bias_, ctx.eps) return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. 
+ """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index fd26f67e8..0edc5e37b 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device from torch import Tensor from torch.nn.parameter import Parameter from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm - from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule from ..utils import divide, set_tensor_parallel_attribute_by_partition from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, split_forward_gather_backward) +from ._operation import linear_with_async_comm @LAYERS.register_module @@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule): dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. """ + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype) + from apex.normalization import FusedLayerNorm + + fast_ln_installed = False + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + fast_ln_installed = True + except ImportError: + pass + + if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes: + norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype) + else: + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) super().__init__(norm) def _load_from_state_dict(self, state_dict, prefix, *args): @@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer): 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. 
Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output_parallel = F.linear(input_parallel, self.weight, bias) + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) if self.gather_output: # All-gather across the partitions. output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer): input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) if not self.skip_bias_add: diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index eb045f2b4..aeba5cc9d 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -9,7 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from ._utils import get_parallel_mode_from_env +from ._utils import get_parallel_mode_from_env, push_async_grad from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D @@ -17,34 +17,27 @@ class _Linear3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, - input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - input_dim: int = 0, - weight_dim: int = -1, - output_dim: int = 0) -> Tensor: - ctx.use_bias = bias is not None - - input_ = all_gather(input_, input_dim, input_parallel_mode) - weight = all_gather(weight, weight_dim, weight_parallel_mode) - ctx.save_for_backward(input_, weight) - - output = torch.matmul(input_, weight) - output = reduce_scatter(output, output_dim, output_parallel_mode) - - if bias is not None: - output += bias - + def forward( + ctx, + input_: Tensor, + weight: Tensor, + weight_id: int, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id ctx.input_parallel_mode = input_parallel_mode ctx.weight_parallel_mode = weight_parallel_mode ctx.output_parallel_mode = output_parallel_mode - ctx.input_dim = input_dim - ctx.weight_dim = weight_dim - ctx.output_dim = output_dim + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight, -1, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + return output @staticmethod @@ -52,73 +45,70 @@ class _Linear3D(torch.autograd.Function): def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors with torch.no_grad(): - output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode) - - async_ops = list() + output_grad = 
all_gather(output_grad, 0, ctx.output_parallel_mode) input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) - input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True) - async_ops.append(op) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) weight_grad = torch.matmul( input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) - weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) + weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) - if ctx.use_bias: - bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) - bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) - else: - bias_grad = None + input_op.wait() - for op in async_ops: - if op is not None: - op.wait() - - return input_grad, weight_grad, bias_grad, None, None, None, None, None, None + return input_grad, weight_grad, None, None, None, None -def linear_3d(input_: Tensor, - weight: Tensor, - bias: Optional[Tensor], - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode, - input_dim: int = 0, - weight_dim: int = -1, - output_dim: int = 0) -> Tensor: +def linear_3d( + input_: Tensor, + weight: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: r"""Linear layer for 3D parallelism. Args: input_ (:class:`torch.tensor`): input matrix. weight (:class:`torch.tensor`): matrix of weight. - bias (:class:`torch.tensor`): matrix of bias. input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - input_dim (int, optional): dimension of input, defaults to 0. - weight_dim (int, optional): dimension of weight, defaults to -1. - output_dim (int, optional): dimension of output, defaults to 0. Note: The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode, - input_dim, weight_dim, output_dim) + return _Linear3D.apply( + input_, + weight, + id(weight), + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) class _Classifier3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: ctx.use_bias = bias is not None + ctx.weight_id = weight_id - ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) - src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] + src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)] weight = broadcast(weight, src_rank, input_parallel_mode) ctx.save_for_backward(input_, weight) @@ -126,6 +116,7 @@ class _Classifier3D(torch.autograd.Function): output = all_reduce(output, output_parallel_mode) if bias is not None: + ctx.bias_id = bias_id output += bias ctx.src_rank = src_rank @@ -139,14 +130,12 @@ class _Classifier3D(torch.autograd.Function): def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors with torch.no_grad(): - async_ops = list() - weight_grad = torch.matmul( output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) else: weight_grad = None @@ -154,21 +143,23 @@ class _Classifier3D(torch.autograd.Function): bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) - async_ops.append(op) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) else: bias_grad = None input_grad = torch.matmul(output_grad, weight) - for op in async_ops: - if op is not None: - op.wait() - - return input_grad, weight_grad, bias_grad, None, None, None, None, None, None + return input_grad, weight_grad, bias_grad, None, None, None, None, None -def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: +def classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: r"""3D parallel classifier. Args: @@ -183,16 +174,134 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_ The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode) + return _Classifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) + + +class _VocabParallelClassifier3D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + weight_id: int, + bias_id: Optional[int], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.use_bias = bias is not None + ctx.weight_id = weight_id + + input_ = all_gather(input_, 0, input_parallel_mode) + weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode) + ctx.save_for_backward(input_, weight) + + output = torch.matmul(input_, weight) + output = reduce_scatter(output, 0, output_parallel_mode) + + if bias is not None: + ctx.bias_id = bias_id + output += bias + + ctx.input_parallel_mode = input_parallel_mode + ctx.weight_parallel_mode = weight_parallel_mode + ctx.output_parallel_mode = output_parallel_mode + return output + + @staticmethod + @custom_bwd + def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: + input_, weight = ctx.saved_tensors + with torch.no_grad(): + output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode) + + input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) + input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) + + weight_grad = torch.matmul( + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) + + if ctx.use_bias: + bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + else: + bias_grad = None + + input_op.wait() + + return input_grad, weight_grad, bias_grad, None, None, None, None, None + + +def vocab_parallel_classifier_3d( + input_: Tensor, + weight: Tensor, + bias: Optional[Tensor], + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, +) -> Tensor: + r"""3D vocab parallel classifier. + + Args: + input_ (:class:`torch.tensor`): input matrix. + weight (:class:`torch.tensor`): matrix of weight. + bias (:class:`torch.tensor`): matrix of bias. + input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. + weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. + output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. + + Note: + The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found + in `parallel_mode `_ + """ + return _VocabParallelClassifier3D.apply( + input_, + weight, + bias, + id(weight), + id(bias) if bias is not None else None, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + ) class _Layernorm3D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx, + input_: Tensor, + weight: Tensor, + bias: Tensor, + weight_id: int, + bias_id: int, + normalized_shape: int, + eps: float, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, + ) -> Tensor: + ctx.weight_id = weight_id + ctx.bias_id = bias_id + mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape mu = input_ - mean var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape @@ -201,15 +310,13 @@ class _Layernorm3D(torch.autograd.Function): ctx.save_for_backward(mu, sigma, weight) z = mu / sigma - output = weight * z - if bias is not None: - output = output + bias + output = weight * z + bias - ctx.use_bias = bias is not None ctx.normalized_shape = normalized_shape ctx.input_parallel_mode = input_parallel_mode ctx.weight_parallel_mode = weight_parallel_mode ctx.output_parallel_mode = output_parallel_mode + ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode return output @@ -218,17 +325,14 @@ class _Layernorm3D(torch.autograd.Function): def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: mu, sigma, weight = ctx.saved_tensors with torch.no_grad(): - weight_grad = output_grad * mu / sigma - if ctx.use_bias: - bias_grad = output_grad - weight_grad = torch.stack([bias_grad, weight_grad]).contiguous() - else: - bias_grad = None - weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1])) - weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode) - weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode) - if ctx.use_bias: - bias_grad, weight_grad = weight_grad[0], weight_grad[1] + + bias_grad, weight_grad = output_grad, output_grad * mu / sigma + bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1])) + bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True) + bias_grad = push_async_grad(op, bias_grad, ctx.bias_id) + weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1])) + weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True) + weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) dz = output_grad * weight dvar = dz * mu * (-0.5) * sigma**(-3) @@ -236,15 +340,22 @@ class _Layernorm3D(torch.autograd.Function): dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode) - input_grad = dz / sigma + dvar * 2 * mu / \ - ctx.normalized_shape + dmean / ctx.normalized_shape + input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape - return input_grad, weight_grad, bias_grad, None, None, None, None, None + return input_grad, weight_grad, bias_grad, None, None, None, None, None, 
None, None, None -def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, - input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: +def layernorm_3d( + input_: Tensor, + weight: Tensor, + bias: Tensor, + normalized_shape: int, + eps: float, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + output_parallel_mode: ParallelMode, + input_x_weight_parallel_mode: ParallelMode, +) -> Tensor: r"""3D parallel Layernorm. Args: @@ -265,8 +376,19 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode, - output_parallel_mode) + return _Layernorm3D.apply( + input_, + weight, + bias, + id(weight), + id(bias), + normalized_shape, + eps, + input_parallel_mode, + weight_parallel_mode, + output_parallel_mode, + input_x_weight_parallel_mode, + ) def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: @@ -315,17 +437,12 @@ def split_batch_3d(input_: Tensor, The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_. """ - dim_size = input_.size(dim) + if input_.size(dim) <= 1: + return input_ weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_world_size = gpc.get_world_size(weight_parallel_mode) input_world_size = gpc.get_world_size(input_parallel_mode) - - assert dim_size % (input_world_size*weight_world_size) == 0, \ - f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).' - - if input_.size(dim) <= 1: - return input_ output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() return output @@ -464,47 +581,3 @@ def reduce_by_batch_3d(tensor: Tensor, in `parallel_mode `_ """ return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) - - -class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function): - r"""broadcast weight from diagonal. - - Args: - input_ (:class:`torch.tensor`): input matrix. - input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. - weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. - output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found - in `parallel_mode `_ - """ - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, - output_parallel_mode: ParallelMode) -> Tensor: - ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) - src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] - output = broadcast(input_, src_rank, input_parallel_mode) - ctx.src_rank = src_rank - ctx.input_parallel_mode = input_parallel_mode - ctx.weight_parallel_mode = weight_parallel_mode - ctx.output_parallel_mode = output_parallel_mode - return output - - @staticmethod - @custom_bwd - def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: - input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode) - if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): - input_grad = all_reduce(input_grad, ctx.weight_parallel_mode) - else: - input_grad = None - return input_grad, None, None, None - - -def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: - return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode, - output_parallel_mode) diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py index 0622164cd..759810f5e 100644 --- a/colossalai/nn/layer/parallel_3d/_utils.py +++ b/colossalai/nn/layer/parallel_3d/_utils.py @@ -1,8 +1,13 @@ -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from collections import OrderedDict +from functools import partial + +import torch +from torch import Tensor + +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env -from torch import Tensor def get_depth_from_env() -> int: @@ -17,30 +22,17 @@ def get_depth_from_env() -> int: def get_parallel_mode_from_env(group): - assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \ + assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ f'{group} is not valid for 3D tensor parallelism.' 
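    # Illustrative note, not part of the patch: once Initializer_3D has run, the new
    # constants resolve through `env` to the matching parallel modes, e.g.
    #   get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)  -> ParallelMode.PARALLEL_3D_INPUT_X_WEIGHT
    #   get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) -> ParallelMode.PARALLEL_3D_OUTPUT_X_WEIGHT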
return getattr(env, group) -def get_last_group(a, b): - mapping = { - ParallelMode.PARALLEL_3D_INPUT: 'A', - ParallelMode.PARALLEL_3D_WEIGHT: 'B', - ParallelMode.PARALLEL_3D_OUTPUT: 'C', - } - - res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b])) - - if res == 'A': - return ParallelMode.PARALLEL_3D_INPUT - elif res == 'B': - return ParallelMode.PARALLEL_3D_WEIGHT - elif res == 'C': - return ParallelMode.PARALLEL_3D_OUTPUT - - def swap_in_out_group(): env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d + env.input_x_weight_group_3d, env.output_x_weight_group_3d = ( + env.output_x_weight_group_3d, + env.input_x_weight_group_3d, + ) def dbg_check_shape(tensor: Tensor, shape: tuple): @@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple): print(tensor.shape) assert tensor.shape == shape, \ '{} does not match {}'.format(tensor.shape, shape) + + +class AsyncGradientBucket(object): + + def __init__(self): + self.bucket = OrderedDict() + + def __len__(self): + return len(self.bucket) + + def push(self, async_op, grad_tensor, param_id): + self.bucket[param_id] = tuple((async_op, grad_tensor)) + return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device) + + def pop(self, param_id): + grad = None + if param_id in self.bucket: + op, grad = self.bucket.pop(param_id) + if op is not None: + op.wait() + return grad + + def synchronize(self, params): + for p in params: + i = id(p) + if i in self.bucket: + op, grad = self.bucket.pop(i) + if op is not None: + op.wait() + p.grad.add_(grad) + + +_async_grad_bucket = AsyncGradientBucket() + + +def push_async_grad(op, grad, param_id): + return _async_grad_bucket.push(op, grad, param_id) + + +def pop_async_grad(param_id): + return _async_grad_bucket.pop(param_id) + + +def _async_grad_hook(grad, param_id): + grad.add_(pop_async_grad(param_id)) + return grad + + +def register_async_grad_hook(param): + param.register_hook(partial(_async_grad_hook, param_id=id(param))) + + +def synchronize(params=list()): + _async_grad_bucket.synchronize(params) + torch.cuda.default_stream().synchronize() + if len(_async_grad_bucket) > 0: + raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.") diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 037a09763..6b3a7f4cc 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from colossalai.communication import all_reduce, broadcast -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env @@ -20,9 +20,9 @@ from torch import Tensor from torch.nn import Parameter from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d, - linear_3d, reduce_scatter_tensor_3d, split_tensor_3d) -from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group +from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d, + 
reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d) +from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook @LAYERS.register_module @@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer): super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() self.normalized_shape = normalized_shape self.normalized_shape_per_partition = divide(normalized_shape, self.depth) @@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer): else: self.bias = None self.variance_epsilon = eps + self.reset_parameters() self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self) -> None: @@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer): def reset_parameters(self) -> None: init.ones_()(self.weight) + register_async_grad_hook(self.weight) if self.bias is not None: init.zeros_()(self.bias) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, - self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) + return layernorm_3d( + input_, + self.weight, + self.bias, + self.normalized_shape, + self.variance_epsilon, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + self.input_x_weight_parallel_mode, + ) @LAYERS.register_module @@ -161,6 +174,7 @@ class Linear3D(ParallelLayer): out_features: int, bias: bool = True, dtype: torch.dtype = None, + skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -168,8 +182,10 @@ class Linear3D(ParallelLayer): self.out_features = out_features self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() + self.skip_bias_add = skip_bias_add self.in_features_per_partition = divide(in_features, self.depth) self.out_features_per_partition = divide(out_features, self.depth**2) self.bias_features_per_partition = divide(out_features, self.depth) @@ -194,18 +210,23 @@ class Linear3D(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.depth) + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode) + return grad + def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + 
register_async_grad_hook(self.weight) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + self.bias.register_hook(self._sync_grad_hook) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -324,8 +345,20 @@ class Linear3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + output = linear_3d( + input_, + self.weight, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias @LAYERS.register_module @@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer): self.num_classes = num_classes self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) @@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer): def reset_parameters(self, weight_initializer, bias_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.in_features, self.num_classes - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode) + + register_async_grad_hook(self.weight) if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) - broadcast(self.bias, input_src_rank, self.input_parallel_mode) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) + return classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) @LAYERS.register_module @@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer): self.num_classes = num_classes self.input_parallel_mode = 
get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D) self.depth = get_depth_from_env() self.in_features_per_partition = divide(in_features, self.depth) self.out_features_per_partition = divide(num_classes, self.depth**2) @@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer): if self.has_weight: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + register_async_grad_hook(self.weight) + if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, output_src_rank, self.output_parallel_mode) + broadcast(self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode) + register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() @@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode, - self.weight_parallel_mode, self.output_parallel_mode) + return vocab_parallel_classifier_3d( + input_, + self.weight, + self.bias, + self.input_parallel_mode, + self.weight_parallel_mode, + self.output_parallel_mode, + ) @LAYERS.register_module @@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) - self.patch_size = to_2tuple(patch_size) - grid_size = to_2tuple(img_size // patch_size) - num_patches = grid_size[0] * grid_size[1] + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] self.embed_size = embed_size - embed_size_per_partition = divide(embed_size, self.depth) + embed_size_per_partition = embed_size // self.depth self.flatten = flatten self.weight = nn.Parameter( @@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer): self.cls_token = nn.Parameter( torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attributes() @@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer): 
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) def _sync_grad_hook(self, grad) -> Tensor: - grad = all_reduce(grad.clone(), self.input_parallel_mode) - grad = all_reduce(grad, self.weight_parallel_mode) + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) return grad def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: @@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer): bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) - broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) - broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) - broadcast(self.weight, input_src_rank, self.input_parallel_mode) - broadcast(self.bias, input_src_rank, self.input_parallel_mode) - broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) + src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0] + broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode) + broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode) self.weight.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook) @@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) - input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) @@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) + self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D) self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer): def _set_tensor_parallel_attributes(self) -> None: set_tensor_parallel_attribute_by_partition(self.weight, self.depth) + def _sync_grad_hook(self, grad) -> Tensor: + grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode) + return grad + def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() - weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] - broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) + broadcast(self.weight, + gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], 
self.input_x_weight_parallel_mode) + self.weight.register_hook(self._sync_grad_hook) def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: @@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) - input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) - weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode, - self.output_parallel_mode) - output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + input_ = split_batch_3d(input_, + input_parallel_mode=self.input_parallel_mode, + weight_parallel_mode=self.weight_parallel_mode) + output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer): self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) + self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D) self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) self.embed_dim_per_partition = divide(self.embed_dim, self.depth) vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) diff --git a/docker/Dockerfile b/docker/Dockerfile index 4b55dc1eb..bcb7c0fff 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,12 +6,12 @@ RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch # install apex RUN git clone https://github.com/NVIDIA/apex && \ cd apex && \ - pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./ # install colossalai RUN git clone https://github.com/hpcaitech/ColossalAI.git \ - && cd ./ColossalAI \ - && pip install -v --no-cache-dir . + && cd ./ColossalAI \ + && pip install -v --no-cache-dir . 
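The extra --fast_layer_norm build option matters because LayerNorm1D (earlier in this patch) now prefers apex's FastLayerNorm whenever it is importable and the hidden size is in its supported list, falling back to FusedLayerNorm otherwise. A minimal standalone sanity check along these lines (the helper name is hypothetical, not part of the image) confirms both extensions were built:

# Standalone sanity check (hypothetical helper, not part of the image): confirm the
# apex build above provides the kernels the updated LayerNorm1D prefers.
def check_apex_layer_norms():
    from apex.normalization import FusedLayerNorm            # requires --cuda_ext
    try:
        # Only present when apex is built with --fast_layer_norm.
        from apex.contrib.layer_norm.layer_norm import FastLayerNorm
    except ImportError:
        FastLayerNorm = None
    return FusedLayerNorm, FastLayerNorm

if __name__ == '__main__':
    fused, fast = check_apex_layer_norms()
    print('FusedLayerNorm available:', fused is not None)
    print('FastLayerNorm available :', fast is not None)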
 # install titans
 RUN pip install --no-cache-dir titans
diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
index d398c4365..9e199e22e 100644
--- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -20,7 +20,6 @@ def check_linear():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
 
@@ -32,12 +31,12 @@
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
 
-    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
+    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
     layer = layer.to(device)
     layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
     layer_master = layer_master.to(device)
 
-    weight_master = layer_master.weight.data.transpose(0, 1)
+    weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
@@ -49,7 +48,7 @@
     layer.bias.data.copy_(bias)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -72,7 +71,7 @@
     logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -108,7 +107,6 @@ def check_layernorm():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -119,7 +117,7 @@
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
 
-    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
+    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
     norm = norm.to(device)
     norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
     norm_master = norm_master.to(device)
@@ -134,7 +132,7 @@
     norm.bias.data.copy_(bias)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -159,7 +157,7 @@
     logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -204,10 +201,10 @@
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
 
-    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
+    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer = layer.to(device)
 
-    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
+    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer_master = layer_master.to(device)
 
     weight_master = layer_master.weight.data
@@ -219,7 +216,7 @@
     layer.bias.data.copy_(bias_master)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -242,7 +239,7 @@
     logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -295,10 +291,10 @@
     k = global_context.get_local_rank(output_parallel_mode)
 
     layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
 
     layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
 
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -312,7 +308,7 @@
     layer.bias.data.copy_(bias)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -336,7 +332,7 @@
     logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     k = global_context.get_local_rank(output_parallel_mode)
 
     embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    embed = embed.to(dtype).to(device)
+    embed = embed.to(device)
 
     embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    embed_master = embed_master.to(dtype).to(device)
+    embed_master = embed_master.to(device)
 
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     embed.weight.data.copy_(weight)
 
     layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
 
     layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH)
     A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
@@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -546,12 +541,12 @@ def check_patch_embed():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
 
-    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer.cls_token)
     torch.nn.init.ones_(layer.pos_embed)
     layer = layer.to(device)
 
-    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer_master.cls_token)
     torch.nn.init.ones_(layer_master.pos_embed)
     layer_master = layer_master.to(device)
@@ -566,7 +561,7 @@ def check_patch_embed():
     layer.bias.data.copy_(proj_bias)
 
     A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = A_master.clone()
 
@@ -586,7 +581,7 @@ def check_patch_embed():
     logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -639,9 +634,9 @@ def check_embed():
     k = global_context.get_local_rank(output_parallel_mode)
 
     layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
 
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -669,7 +664,7 @@ def check_embed():
     logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -686,10 +681,7 @@ def check_embed():
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    if j == k:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
-    else:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
+    logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
 
@@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
     k = global_context.get_local_rank(output_parallel_mode)
 
     layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
 
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
     logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -771,7 +763,6 @@ def check_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
    device = get_current_device()
-    dtype = torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -783,8 +774,8 @@ def check_loss():
     criterion_master = torch.nn.CrossEntropyLoss()
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
     criterion_master = torch.nn.CrossEntropyLoss()
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py
index 32ab63711..afb19c474 100644
--- a/tests/test_layers/test_3d/checks_3d/common.py
+++ b/tests/test_layers/test_3d/checks_3d/common.py
@@ -12,8 +12,8 @@ NUM_BLOCKS = 2
 IMG_SIZE = 16
 VOCAB_SIZE = 16
 
+
 def check_equal(A, B):
     eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
-    assert eq
-    return eq
-
+    assert eq, f"\nA = {A}\nB = {B}"
+    return eq
\ No newline at end of file
diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py
index c79dde2a1..29a8b3aea 100644
--- a/tests/test_layers/test_3d/test_3d.py
+++ b/tests/test_layers/test_3d/test_3d.py
@@ -10,9 +10,8 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
-from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
-                                      check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
-                                      check_vocab_parallel_classifier_given_embed_weight,
+from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
+                                      check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
                                       check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
                                       check_vocab_parallel_loss)
 
@@ -30,7 +29,6 @@ def check_layer():
     check_layernorm()
     check_classifier_no_given_weight()
     check_vocab_parallel_classifier_no_given_weight()
-    check_classifier_given_embed_weight()
     check_vocab_parallel_classifier_given_embed_weight()
     check_embed()
     check_patch_embed()
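Note on the Embedding3D change above: the forward pass now shards the token batch with split_batch_3d across both the input and the weight groups, keeps the local embedding weight whole for the F.embedding lookup, and relies on the registered _sync_grad_hook to keep the replicated weight gradients consistent, which is why the embed backward check now compares weight.grad on every rank. Below is a minimal, hypothetical sketch of the batch split only, assuming DEPTH-way chunking along dim 0 by the local ranks of the weight and input modes (mirroring the removed split_tensor_3d calls); the helper name and its rank arguments are illustrative and are not the ColossalAI implementation.

# Illustration only: a stand-in for the assumed behaviour of split_batch_3d,
# chunking the batch dimension by the weight-group rank and then by the
# input-group rank, as the replaced split_tensor_3d calls did along dim 0.
import torch


def split_batch_3d_sketch(input_: torch.Tensor, depth: int,
                          input_rank: int, weight_rank: int) -> torch.Tensor:
    shard = torch.chunk(input_, depth, dim=0)[weight_rank]   # weight-group split
    shard = torch.chunk(shard, depth, dim=0)[input_rank]     # input-group split
    return shard


# e.g. DEPTH = 2 on an 8-token batch: each rank keeps a 2-row shard
tokens = torch.arange(8).reshape(8, 1)
print(split_batch_3d_sketch(tokens, depth=2, input_rank=0, weight_rank=1))
# tensor([[4],
#         [5]])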