diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py index 6f23def9c..4283e5fe0 100644 --- a/colossalai/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/nn/layer/colossalai_layer/_utils.py @@ -1,12 +1,12 @@ import torch.nn as nn from torch import Tensor -from ..parallel_2d._operation import split_tensor_2d -from ..parallel_2p5d._operation import split_tensor_2p5d +from ..parallel_2d._operation import split_batch_2d +from ..parallel_2p5d._operation import split_batch_2p5d from ..parallel_3d._operation import split_batch_3d from ..utils import get_tensor_parallel_mode -_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d} +_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} def partition_batch(input_) -> Tensor: diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py index 9bb62b456..5562d1a70 100644 --- a/colossalai/nn/layer/parallel_2d/__init__.py +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -1,8 +1,8 @@ -from ._operation import reduce_by_batch_2d, split_tensor_2d +from ._operation import reduce_by_batch_2d, split_batch_2d from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, VocabParallelEmbedding2D) __all__ = [ - 'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', + 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' ] diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index ac6f00b27..d6fe58f1b 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -720,7 +720,7 @@ def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) return _AllGatherTensor2D.apply(tensor, dim, parallel_mode) -def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: +def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: """Splits 2D tensor in specified dimension across cols. Args: @@ -730,6 +730,11 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: Returns: :class:`torch.tensor`: The tensor has been split. """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + if input_.size(dim) <= 1: return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), @@ -784,6 +789,11 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index e2cc52801..1ba7768ac 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -19,7 +19,7 @@ from torch.nn import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, - reduce_scatter_tensor_2d, split_tensor_2d) + reduce_scatter_tensor_2d, split_batch_2d) from ._utils import assert_summa_initialization, get_summa_dim_from_env @@ -547,7 +547,7 @@ class PatchEmbedding2D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_2d(input_) + input_ = split_batch_2d(input_) B, C, H, W = input_.shape assert H == self.img_size[0] and W == self.img_size[1], \ @@ -692,7 +692,7 @@ class Embedding2D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_2d(input_) + input_ = split_batch_2d(input_) weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py index 5ca351605..bec3b1c4b 100644 --- a/colossalai/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -1,8 +1,8 @@ -from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d +from ._operation import reduce_by_batch_2p5d, split_batch_2p5d from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) __all__ = [ - 'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', + 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' ] diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py index 536ba8305..38f6bba72 100644 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -755,7 +755,7 @@ class SplitFirst(torch.autograd.Function): return grad, None, None -def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: +def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: """Splits 2P5D tensor in specified dimension across cols. Args: @@ -765,6 +765,11 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: Returns: :class:`torch.tensor`: The tensor has been split. """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + if input_.size(dim) <= 1: return input_ return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), @@ -819,6 +824,11 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ + dim_size = input_.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index a3120c0c1..dd188885f 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -20,7 +20,7 @@ from torch.nn import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d, - layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d) + layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d) from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env @@ -568,7 +568,7 @@ class PatchEmbedding2p5D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_2p5d(input_, 0) + input_ = split_batch_2p5d(input_, 0) B, C, H, W = input_.shape assert H == self.img_size[0] and W == self.img_size[1], \ @@ -713,7 +713,7 @@ class Embedding2p5D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_tensor_2p5d(input_, 0) + input_ = split_batch_2p5d(input_, 0) weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL) diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py index 8cf95f519..01251535f 100644 --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -276,6 +276,11 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_. """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' if tensor.size(dim) <= 1: return tensor output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), @@ -302,13 +307,20 @@ def split_batch_3d(input_: Tensor, The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_. """ - if input_.size(dim) <= 1: - return input_ + dim_size = input_.size(dim) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode), + weight_world_size = gpc.get_world_size(weight_parallel_mode) + input_world_size = gpc.get_world_size(input_parallel_mode) + + assert dim_size % (input_world_size*weight_world_size) == 0, \ + f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).' + + if input_.size(dim) <= 1: + return input_ + output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() - output = torch.chunk(output, gpc.get_world_size(input_parallel_mode), + output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() return output @@ -394,6 +406,11 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ + dim_size = tensor.size(dim) + world_size = gpc.get_world_size(parallel_mode) + assert dim_size % world_size == 0, \ + f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' + return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py index 8e685bd3a..cb12e723c 100644 --- a/colossalai/nn/loss/loss_2d.py +++ b/colossalai/nn/loss/loss_2d.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.registry import LOSSES from colossalai.utils import get_current_device @@ -48,7 +48,7 @@ class CrossEntropyLoss2D(_Loss): Returns: float: the loss between logits and targets. """ - targets = split_tensor_2d(targets) + targets = split_batch_2d(targets) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() @@ -145,7 +145,7 @@ class VocabParallelCrossEntropyLoss2D(_Loss): logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ - targets = split_tensor_2d(targets) + targets = split_batch_2d(targets) loss = _VocabParallelCrossEntropy2D.apply( logits, targets, diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py index 5aadcef1f..ed58c13f8 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/nn/loss/loss_2p5d.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.registry import LOSSES from colossalai.utils import get_current_device @@ -44,7 +44,7 @@ class CrossEntropyLoss2p5D(_Loss): logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ - targets = split_tensor_2p5d(targets) + targets = split_batch_2p5d(targets) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() @@ -138,7 +138,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss): logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ - targets = split_tensor_2p5d(targets) + targets = split_batch_2p5d(targets) loss = _VocabParallelCrossEntropy2p5D.apply(logits, targets) if self.reduction_mean: loss = loss.mean() diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py index 95834aec5..1137d1963 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/nn/metric/accuracy_2d.py @@ -1,5 +1,5 @@ import torch -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from torch import nn from ._utils import calc_acc @@ -22,7 +22,7 @@ class Accuracy2D(nn.Module): float: the accuracy of prediction. """ with torch.no_grad(): - targets = split_tensor_2d(targets) + targets = split_batch_2d(targets) correct = calc_acc(logits, targets) correct = reduce_by_batch_2d(correct) return correct diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py index 08ce90083..337c6af4a 100644 --- a/colossalai/nn/metric/accuracy_2p5d.py +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -1,5 +1,5 @@ import torch -from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_tensor_2p5d +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from torch import nn from ._utils import calc_acc @@ -22,7 +22,7 @@ class Accuracy2p5D(nn.Module): float: the accuracy of prediction. """ with torch.no_grad(): - targets = split_tensor_2p5d(targets) + targets = split_batch_2p5d(targets) correct = calc_acc(logits, targets) correct = reduce_by_batch_2p5d(correct) return correct