From f22ddacef03c00ac19ac20ca7d5274f9ab4c9ff1 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 21 Jun 2023 14:30:06 +0800 Subject: [PATCH] [shardformer] refactored the shardformer layer structure (#4053) --- .../shardformer/{utils/utils.py => _utils.py} | 3 + colossalai/shardformer/layer/__init__.py | 19 +- colossalai/shardformer/layer/_operation.py | 2 - colossalai/shardformer/layer/dropout.py | 4 +- ...cabparallelembedding1d.py => embedding.py} | 165 +++++++++++++++--- colossalai/shardformer/layer/embedding1d.py | 157 ----------------- colossalai/shardformer/layer/layernorm1d.py | 73 -------- .../layer/{linear1d.py => linear.py} | 22 +-- .../layer/{linearconv1d.py => linear_conv.py} | 53 +++--- .../layer/{dist_crossentropy.py => loss.py} | 4 +- .../{parallelmodule.py => parallel_module.py} | 10 +- colossalai/shardformer/policies/basepolicy.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/gpt2.py | 5 - colossalai/shardformer/policies/t5.py | 2 - colossalai/shardformer/shard/sharder.py | 2 +- colossalai/shardformer/utils/__init__.py | 1 - .../test_dist_crossentropy.py} | 4 +- .../test_layer/test_dropout.py | 2 +- .../test_layer/test_embedding.py | 2 +- .../test_layer/test_linear_1d.py | 2 +- .../test_vocab_parallel_embedding_1d.py | 2 +- .../test_module/test_dropout.py | 51 ------ .../test_module/test_slicer.py | 78 --------- 24 files changed, 196 insertions(+), 471 deletions(-) rename colossalai/shardformer/{utils/utils.py => _utils.py} (97%) rename colossalai/shardformer/layer/{vocabparallelembedding1d.py => embedding.py} (52%) delete mode 100644 colossalai/shardformer/layer/embedding1d.py delete mode 100644 colossalai/shardformer/layer/layernorm1d.py rename colossalai/shardformer/layer/{linear1d.py => linear.py} (96%) rename colossalai/shardformer/layer/{linearconv1d.py => linear_conv.py} (92%) rename colossalai/shardformer/layer/{dist_crossentropy.py => loss.py} (98%) rename colossalai/shardformer/layer/{parallelmodule.py => parallel_module.py} (78%) delete mode 100644 colossalai/shardformer/utils/__init__.py rename tests/test_shardformer/{test_module/test_distcrossentropy.py => test_layer/test_dist_crossentropy.py} (87%) delete mode 100644 tests/test_shardformer/test_module/test_dropout.py delete mode 100644 tests/test_shardformer/test_module/test_slicer.py diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/_utils.py similarity index 97% rename from colossalai/shardformer/utils/utils.py rename to colossalai/shardformer/_utils.py index 05a6a3ae6..a1c7203a9 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/_utils.py @@ -2,6 +2,9 @@ import re def get_obj_list_element(obj, a): + r""" + Get the element of the list in the object + """ re_pattern = r'\[\d+\]' prog = re.compile(re_pattern) result = prog.search(a) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 66d86913b..808ebbc12 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,17 +1,10 @@ from .dropout import Dropout1D -from .embedding1d import Embedding1D -from .layernorm1d import LayerNorm1D -from .linear1d import Linear1D_Col, Linear1D_Row -from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row -from .vocabparallelembedding1d import VocabParallelEmbedding1D +from .embedding import Embedding1D, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row +from .linear_conv import LinearConv1D_Col, LinearConv1D_Row 
+from .loss import cross_entropy_1d __all__ = [ - "Embedding1D", - "VocabParallelEmbedding1D", - "Linear1D_Col", - "Linear1D_Row", - "LinearConv1D_Col", - "LinearConv1D_Row", - "LayerNorm1D", - "Dropout1D", + "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", + "Dropout1D", "cross_entropy_1d" ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 208a391c3..280d55263 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,8 +1,6 @@ import torch import torch.distributed as dist -from colossalai.core import global_context as gpc - try: import fused_mix_prec_layer_norm_cuda except: diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 08dfb8afd..2c49b49fa 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -4,9 +4,11 @@ import torch import torch.nn as nn from torch.distributed import ProcessGroup -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset +__all__ = ['Dropout1D'] + class Dropout1D(ParallelModule, nn.Dropout): """ diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/embedding.py similarity index 52% rename from colossalai/shardformer/layer/vocabparallelembedding1d.py rename to colossalai/shardformer/layer/embedding.py index 4c325c684..8b9fb03ec 100644 --- a/colossalai/shardformer/layer/vocabparallelembedding1d.py +++ b/colossalai/shardformer/layer/embedding.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from collections import OrderedDict from typing import Callable, List, Union import torch @@ -12,26 +11,148 @@ from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from colossalai.context import ParallelMode, seed from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_rowwise -from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device -from ._operation import reduce_input -from .parallelmodule import ParallelModule +from ._operation import gather_forward_split_backward, reduce_input +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['Embedding1D', 'VocabParallelEmbedding1D'] -class VocabParallelEmbedding1D(ParallelLayer): +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. 
+ + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.gather_output = gather_output + + if device is None: + device = get_current_device() + + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. 
+ """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + *args, + **kwargs) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel + + +class VocabParallelEmbedding1D(ParallelModule): r"""Embedding parallelized in the vocabulary dimension. Args: @@ -93,9 +214,7 @@ class VocabParallelEmbedding1D(ParallelLayer): # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -132,7 +251,7 @@ class VocabParallelEmbedding1D(ParallelLayer): return vocab_embedding_1d def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): + with self.randomizer.fork_rng(enable_cpu=True): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() @@ -143,16 +262,6 @@ class VocabParallelEmbedding1D(ParallelLayer): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - def forward(self, input_: Tensor) -> Tensor: # Build the mask. 
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py deleted file mode 100644 index ace7deb3a..000000000 --- a/colossalai/shardformer/layer/embedding1d.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Callable, List, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.nn import init as init -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise -from colossalai.utils.cuda import get_current_device - -from ._operation import gather_forward_split_backward -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class Embedding1D(ParallelModule): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. 
- - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.gather_output = gather_output - - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": - r""" - Build a 1D parallelized Embedding from a native nn.Embedding module. - """ - # get the attributes - num_embedding = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - max_norm = module.max_norm - norm_type = module.norm_type - scale_grad_by_freq = module.scale_grad_by_freq - sparse = module.sparse - dtype = module.weight.dtype - device = module.weight.device - - # sparse is not support yet - if sparse: - raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - *args, - **kwargs) - - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - - return embedding - - def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.gather_output: - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output - else: - return output_parallel diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py deleted file mode 100644 index 78bd64cfb..000000000 --- a/colossalai/shardformer/layer/layernorm1d.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from collections import OrderedDict - -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc 
-from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.utils.checkpointing import broadcast_state_dict - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear.py similarity index 96% rename from colossalai/shardformer/layer/linear1d.py rename to colossalai/shardformer/layer/linear.py index d59d32df8..b87981c6d 100644 --- a/colossalai/shardformer/layer/linear1d.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,15 +23,10 @@ from ._operation import ( reduce_input, split_forward_gather_backward, ) -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['Linear1D_Col', 'Linear1D_Row'] class Linear1D_Col(ParallelModule): @@ -104,8 +99,8 @@ class Linear1D_Col(ParallelModule): seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with 
self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -146,10 +141,11 @@ class Linear1D_Col(ParallelModule): return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[-1], \ diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linear_conv.py similarity index 92% rename from colossalai/shardformer/layer/linearconv1d.py rename to colossalai/shardformer/layer/linear_conv.py index 4a5cb0707..b4599f489 100644 --- a/colossalai/shardformer/layer/linearconv1d.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -23,19 +23,15 @@ from ._operation import ( reduce_input, split_forward_gather_backward, ) -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. + Specially created for HuggingFace's GPT2 model. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. @@ -104,8 +100,8 @@ class LinearConv1D_Col(ParallelModule): seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, @@ -162,10 +158,11 @@ class LinearConv1D_Col(ParallelModule): return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[-1], \ @@ -192,6 +189,7 @@ class LinearConv1D_Col(ParallelModule): class LinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism + Specially created for HuggingFace's GPT2 model. 
Args: in_features (int): size of each input sample. @@ -260,8 +258,8 @@ class LinearConv1D_Row(ParallelModule): seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, @@ -320,20 +318,21 @@ class LinearConv1D_Row(ParallelModule): self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. 
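A minimal usage sketch of the renamed linear modules (illustrative only, not part of this patch): it assumes torch.distributed is already initialized, e.g. through colossalai.launch as in the tests further below, that Linear1D_Row offers the same from_native_module interface shown above for Linear1D_Col, and that the shard_mlp helper and the layer sizes are hypothetical. Whether an explicit gather is required between the two projections depends on the layers' defaults, which this sketch does not pin down.

import torch.nn as nn
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row


def shard_mlp(mlp: nn.Sequential, process_group: dist.ProcessGroup = None) -> nn.Sequential:
    # Column-parallel first projection, row-parallel second projection:
    # the usual 1D tensor-parallel layout for a two-layer MLP.
    fc1 = Linear1D_Col.from_native_module(mlp[0], process_group)
    fc2 = Linear1D_Row.from_native_module(mlp[2], process_group)
    return nn.Sequential(fc1, mlp[1], fc2)


if dist.is_initialized():
    # Hypothetical sizes; the sharded hidden dimension must be divisible by the world size.
    mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    sharded_mlp = shard_mlp(mlp)
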
diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/loss.py similarity index 98% rename from colossalai/shardformer/layer/dist_crossentropy.py rename to colossalai/shardformer/layer/loss.py index 7840c2f2e..38a5395a0 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,10 +1,10 @@ import torch import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup +__all__ = ['DistCrossEntropy', 'cross_entropy_1d'] + class DistCrossEntropy(Function): r""" diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallel_module.py similarity index 78% rename from colossalai/shardformer/layer/parallelmodule.py rename to colossalai/shardformer/layer/parallel_module.py index 3d19bbea7..c68cd5778 100644 --- a/colossalai/shardformer/layer/parallelmodule.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -7,15 +7,7 @@ from typing import List, Union import torch.nn as nn from torch.distributed import ProcessGroup -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['ParallelModule'] class ParallelModule(nn.Module, ABC): diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 175a914a8..b5d9cdbd7 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Type, Union import torch.nn as nn diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 2a204f0de..d5e8e01cf 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,7 +4,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be import colossalai.shardformer.layer as col_nn from colossalai.shardformer.layer.dropout import Dropout1D -from ..utils import getattr_, setattr_ +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d255325b2..da9e6b7bd 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,12 +1,7 @@ -from typing import Type, Union - -import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.layer.dropout import Dropout1D -from ..utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 9e0c86049..30433f751 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,5 +1,3 @@ -import torch -import torch.nn as nn from transformers import T5ForConditionalGeneration from transformers.models.t5.modeling_t5 import ( T5Attention, diff --git 
a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 66934b09b..22f5f1c12 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -4,9 +4,9 @@ import torch.nn as nn from colossalai.cluster.process_group_manager import ProcessGroupManager +from .._utils import getattr_, setattr_ from ..policies.autopolicy import get_autopolicy from ..policies.basepolicy import Policy, SubModuleReplacementDescription -from ..utils.utils import getattr_, setattr_ from .shard_config import ShardConfig __all__ = ['ModelSharder', 'shard_model'] diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py deleted file mode 100644 index b50e7b2f6..000000000 --- a/colossalai/shardformer/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import getattr_, hasattr_, setattr_ diff --git a/tests/test_shardformer/test_module/test_distcrossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py similarity index 87% rename from tests/test_shardformer/test_module/test_distcrossentropy.py rename to tests/test_shardformer/test_layer/test_dist_crossentropy.py index 9a19ec578..72e6e5cf2 100644 --- a/tests/test_shardformer/test_module/test_distcrossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -4,7 +4,7 @@ import torch.nn.functional as F import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.shardformer.layer import cross_entropy_1d from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) @@ -25,7 +25,7 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss = F.cross_entropy(org_pred, org_labels) dist_pred = pred.chunk(world_size, -1)[rank] - dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) assert torch.allclose(org_loss, dist_loss, atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index c48c11b36..c62d25d94 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -3,7 +3,7 @@ import torch.distributed as dist import torch.nn as nn import colossalai -from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.shardformer.layer import Dropout1D from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 462349ecb..70500008c 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import Embedding1D +from colossalai.shardformer.layer import Embedding1D from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 2a3ce9938..00ecc37ce 100644 --- 
a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 3df53e8a8..bee44a2fb 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D +from colossalai.shardformer.layer import VocabParallelEmbedding1D from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py deleted file mode 100644 index 4a13eb61c..000000000 --- a/tests/test_shardformer/test_module/test_dropout.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.dropout import Dropout1D -from colossalai.testing import rerun_if_address_is_in_use, spawn - -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) - - -def check_dropout(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') - - # prepare data - input = torch.randn(5, 4).to('cuda') - dropout = Dropout1D(p=0.4).to('cuda') - output_list = [] - # compare the dropout pattern in each device - for i in range(2): - output = dropout(input) - output_list.append(output) - dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] - torch.distributed.all_gather(dist_output_list, output) - for j in range(world_size): - for k in range(world_size): - if j != k: - mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0) - assert torch.all( - mask - ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}" - # compare the dropout pattern in loacl device - for i in range(len(output_list)): - for j in range(len(output_list)): - if i != j: - mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0) - assert torch.all( - mask - ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}" - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_dropout(): - spawn(check_dropout, 2) - - -if __name__ == '__main__': - test_dropout() diff --git a/tests/test_shardformer/test_module/test_slicer.py b/tests/test_shardformer/test_module/test_slicer.py deleted file mode 100644 index c72a03575..000000000 --- a/tests/test_shardformer/test_module/test_slicer.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer -from 
colossalai.shardformer.shard.shard_config import ShardConfig -from colossalai.shardformer.shard.slicer import Slicer -from colossalai.testing import rerun_if_address_is_in_use, spawn - -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) - - -def check_slicer(rank, world_size, port, in_feature, out_feature): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') - # initialize slicer - shardconfig = ShardConfig(rank=rank, world_size=world_size) - slicer = Slicer(shardconfig) - # initialize test data - weight = torch.randn(in_feature, out_feature) - bias = torch.randn(out_feature) - policy_layer_cls_list = [Layer, Col_Layer, Row_Layer] - n_cast_list = [None, 2, 3, 4] - # weight and bias - for n_cast in n_cast_list: - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast) - expected_sliced_weight = weight - expected_sliced_bias = bias - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast) - if (n_cast is None): - expected_sliced_weight = weight.chunk(world_size, dim=0)[rank] - expected_sliced_bias = bias.chunk(world_size)[rank] - else: - chunks = weight.chunk(world_size * n_cast, dim=0) - expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0) - chunks = bias.chunk(world_size * n_cast, dim=0) - expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)]) - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}" - - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast) - if (n_cast is None): - expected_sliced_weight = weight.chunk(world_size, dim=1)[rank] - expected_sliced_bias = bias - else: - chunks = weight.chunk(world_size * n_cast, dim=1) - expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1) - expected_sliced_bias = bias - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_slicer(): - args = dict(in_feature=24, out_feature=48) - spawn(check_slicer, nprocs=2, 
in_feature=args['in_feature'], out_feature=args['out_feature']) - - -if __name__ == '__main__': - test_slicer()
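
For completeness, a sketch of the loss path exercised by the relocated cross-entropy test above (illustrative only, not part of this patch). It mirrors that test: every rank keeps one vocabulary shard of the logits and calls the newly exported cross_entropy_1d, whose result should match torch's full cross entropy. The shapes, the check_sharded_loss name, and the tolerance are assumptions; a 1D tensor-parallel process group is expected to be initialized before calling it.

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.shardformer.layer import cross_entropy_1d


def check_sharded_loss(ignore_index: int = -100):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Full logits and labels, replicated on every rank only for the comparison.
    vocab_size = 32 * world_size
    logits = torch.randn(512, vocab_size, device='cuda')
    labels = torch.randint(0, vocab_size, (512,), device='cuda')
    full_loss = F.cross_entropy(logits, labels, ignore_index=ignore_index)

    # Each rank keeps only its slice of the vocabulary dimension, exactly as in
    # tests/test_shardformer/test_layer/test_dist_crossentropy.py above.
    local_logits = logits.chunk(world_size, dim=-1)[rank]
    sharded_loss = cross_entropy_1d(local_logits, labels, ignore_index=ignore_index)

    assert torch.allclose(full_loss, sharded_loss, atol=1e-5), \
        f"sharded loss {sharded_loss} does not match full loss {full_loss}"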