diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index e69de29bb..66d86913b 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -0,0 +1,17 @@ +from .dropout import Dropout1D +from .embedding1d import Embedding1D +from .layernorm1d import LayerNorm1D +from .linear1d import Linear1D_Col, Linear1D_Row +from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row +from .vocabparallelembedding1d import VocabParallelEmbedding1D + +__all__ = [ + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "LinearConv1D_Col", + "LinearConv1D_Row", + "LayerNorm1D", + "Dropout1D", +] diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index ec08d072f..08dfb8afd 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from torch.distributed import ProcessGroup -from .layers import ParallelModule +from .parallelmodule import ParallelModule from .utils import create_randomizer_with_offset diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py new file mode 100644 index 000000000..1108d5d6a --- /dev/null +++ b/colossalai/shardformer/layer/embedding1d.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise +from colossalai.utils.cuda import get_current_device + +from ._operation import gather_forward_split_backward +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. 
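    A sharding sketch (editorial illustration, not part of the original patch; it assumes a 2-way tensor-parallel process group, and ``tp_group``/``input_ids`` are stand-ins supplied by the caller):
    ::

        # nn.Embedding(50257, 768) is split along the embedding dimension, so each of
        # the two ranks holds a (50257, 384) weight shard; forward() performs the local
        # lookup and all-gathers the per-rank outputs along the last dimension, so the
        # result matches the unsharded nn.Embedding.
        embedding1d = Embedding1D.from_native_module(nn.Embedding(50257, 768), process_group=tp_group)
        out = embedding1d(input_ids)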
+ + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + # self.gather_output = gather_output + + if device is None: + device = get_current_device() + + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + + return output diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py new file mode 100644 index 000000000..78bd64cfb --- /dev/null +++ b/colossalai/shardformer/layer/layernorm1d.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict + +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as 
init +from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule +from colossalai.utils.checkpointing import broadcast_state_dict + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py deleted file mode 100644 index 5dbe28956..000000000 --- a/colossalai/shardformer/layer/layers.py +++ /dev/null @@ -1,722 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import Callable, List, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, reduce_grad, 
set_parallel_input -from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition -from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.checkpointing import ( - broadcast_state_dict, - gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict, -) -from colossalai.utils.cuda import get_current_device - -from ._operation import ( - gather_forward_split_backward, - linear_with_async_comm, - reduce_input, - split_forward_gather_backward, -) -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class ParallelModule(nn.Module, ABC): - - @abstractmethod - def from_native_module(module: nn.Module, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": - """ - Convert a native PyTorch module to a parallelized module. - - Args: - module (nn.Module): the module to be converted. - process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. - If this is a list, the process group at the ith index of the list will correspond to the process group - in the ith axis of the device mesh. Defaults to None, which means the global process group. - """ - pass - - -class Linear1D_Col(ParallelModule): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - device (`torch.device`): The device of parameters, defaults to None. - process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (`typing.Callable`): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (`typing.Callable`): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. 
- """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - self.device = device - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - self.out_features_per_partition = divide(out_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - linear_1d = Linear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) - - return linear_1d - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. 
- bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -class Linear1D_Row(ParallelModule): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. 
- """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - linear_1d = Linear1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - - return linear_1d - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) - - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. 
- :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) - - -class Embedding1D(ParallelModule): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. 
- - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.gather_output = gather_output - - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": - r""" - Build a 1D parallelized Embedding from a native nn.Embedding module. - """ - # get the attributes - num_embedding = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - max_norm = module.max_norm - norm_type = module.norm_type - scale_grad_by_freq = module.scale_grad_by_freq - sparse = module.sparse - dtype = module.weight.dtype - device = module.weight.device - - # sparse is not support yet - if sparse: - raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - *args, - **kwargs) - - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - - return embedding - - def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.gather_output: - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output - else: - return output_parallel - - -class VocabParallelEmbedding1D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. 
it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.process_group = process_group - - tensor_parallel_size = dist.get_world_size(group=process_group) - tensor_parallel_rank = dist.get_rank(group=process_group) - - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - # self.reset_parameters(weight_initializer) - # self._set_tensor_parallel_attributes() - # set_parallel_input(False) - # env.vocab_parallel = True - - @staticmethod - def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native pytorch embedding module to a parallel module. - """ - # get the origin attributes - num_embeddings = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - device = module.weight.device - - # ensure only one process group is used - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' 
- process_group = process_group[0] - - # create the parallel module - vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - device=device, - process_group=process_group, - *args, - **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - - return vocab_embedding_1d - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. 
- output = reduce_input(output_parallel, self.process_group) - return output diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear1d.py new file mode 100644 index 000000000..d59d32df8 --- /dev/null +++ b/colossalai/shardformer/layer/linear1d.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. 
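        # --- Editorial sketch (not part of the original patch) ---------------------
        # Column parallelism splits the weight along out_features: with a tensor-parallel
        # group of size p, each rank allocates an (out_features / p, in_features) shard
        # and computes Y_i = X @ A_i^T + b_i. With gather_output=True the Y_i are
        # all-gathered along the last dimension to recover the full (..., out_features)
        # output; with the default gather_output=False every rank keeps only its
        # out_features / p slice. Example: out_features=1024 on 4 ranks -> a
        # (256, in_features) shard per rank.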
+ if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. 
+ skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linearconv1d.py new file mode 100644 index 000000000..4a5cb0707 --- /dev/null +++ b/colossalai/shardformer/layer/linearconv1d.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + 
+from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LinearConv1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. 
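        # --- Editorial note (sketch, not part of the original patch) ----------------
        # HuggingFace GPT-2's ``Conv1D`` stores its weight as (in_features, out_features)
        # and computes x @ W + b, i.e. transposed relative to nn.Linear. That is why
        # from_native_module() below reads in_features from module.weight.shape[0] and
        # transposes the shard (sharded_weight.T) before copying it in. The n_cast
        # argument appears intended for fused projections such as the 3-way QKV c_attn
        # weight: the chunks are re-interleaved so each rank gets a slice of every cast;
        # e.g. for world_size=2, n_cast=3 the permutation new_order is [0, 2, 4, 1, 3, 5].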
+ if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = LinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) + sharded_weight = shard_colwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + sharded_bias = shard_colwise(rearanged_bias, process_group) + linear_1d.bias.copy_(sharded_bias.contiguous()) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. 
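            # Editorial note: gather_output is typically left False when this
            # column-parallel projection feeds a row-parallel LinearConv1D_Row, which
            # consumes the un-gathered shard directly; enabling it all-gathers the
            # per-rank partial outputs along the last dimension so every rank sees the
            # full out_features-wide activation.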
+ output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class LinearConv1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = LinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) + sharded_weight = shard_rowwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + linear_1d.bias.copy_(rearanged_bias.contiguous()) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallelmodule.py new file mode 100644 index 000000000..3d19bbea7 --- /dev/null +++ b/colossalai/shardformer/layer/parallelmodule.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import List, Union + +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. 
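+
+        Returns:
+            ParallelModule: the converted, parallelized module.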
+ """ + pass diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/vocabparallelembedding1d.py new file mode 100644 index 000000000..4c325c684 --- /dev/null +++ b/colossalai/shardformer/layer/vocabparallelembedding1d.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.context import ParallelMode, seed +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_rowwise +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict + +from ._operation import reduce_input +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. 
+ """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.process_group = process_group + + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + *args, + **kwargs) + with torch.no_grad(): + # shard and slice the weight along the vocabulary(num_embeddings) dimension + # the shape of the weight is (num_embeddings, embedding_dim) + shard_weight = shard_rowwise(module.weight.data, process_group) + vocab_embedding_1d.weight.data.copy_(shard_weight) + + return vocab_embedding_1d + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. 
+ input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, self.process_group) + return output diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 6ce0b8fb3..5e7a285e3 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -56,6 +56,8 @@ _POLICY_LIST = { PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 + "transformers.models.gpt2.modeling_gpt2.GPT2Model": + PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), } @@ -99,4 +101,3 @@ def get_autopolicy(model: nn.Module) -> Policy: else: policy = import_policy(policy_location) return policy() - return policy() diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 06ee9b435..2a204f0de 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,7 +1,7 @@ import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead -import colossalai.shardformer.layer.layers as col_nn +import colossalai.shardformer.layer as col_nn from colossalai.shardformer.layer.dropout import Dropout1D from ..utils import getattr_, setattr_ @@ -87,15 +87,9 @@ class BertPolicy(Policy): def new_model_class(self): # do nothing - return None + return self.model def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) - setattr_(self.model, v, param) return self.model @@ -127,6 +121,15 @@ class BertForPretrainingPolicy(BertPolicy): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): @@ -149,6 +152,15 @@ class BertForMaskedLMPolicy(BertPolicy): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): @@ -171,6 +183,15 @@ class BertLMHeadModelPolicy(BertPolicy): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertForNextSentencePrediction class 
BertForNextSentencePredictionPolicy(BertPolicy): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0d4342e75..d255325b2 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,126 +1,101 @@ -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Type, Union import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model -import colossalai.shardformer.layer.layers as col_nn +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.layer.dropout import Dropout1D -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer +from ..utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class GPT2Policy(Policy): - @staticmethod - def argument_policy(config, world_size): + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): return { GPT2Model: - Argument(attr_dict={}, param_funcs=[ - GPT2Policy.embedding, - ]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]), GPT2Block: - Argument( - attr_dict={ - # 1. reduce hidden size - "attn.embed_dim": config.hidden_size // world_size, - "attn.split_size": config.hidden_size // world_size, - "crossattention.embed_dim": config.hidden_size // world_size, - "crossattention.split_size": config.hidden_size // world_size, - # 2. 
reduce number of heads - "attn.num_heads": config.num_attention_heads // world_size, - "crossattention.num_heads": config.num_attention_heads // world_size, - }, - param_funcs=[ - GPT2Policy.attn_in, - GPT2Policy.attn_out, - GPT2Policy.mlp_in, - GPT2Policy.mlp_out, - ]), + ModulePolicyDescription(attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.LinearConv1D_Col, + kwargs={ + "n_cast": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.LinearConv1D_Row, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.LinearConv1D_Col, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.LinearConv1D_Row, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.Dropout1D, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.Dropout1D, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.Dropout1D, + ), + ]) } - @staticmethod - def attn_in() -> List: - return [ - Col_Layer(suffix="attn.c_attn", - weight="weight", - bias="bias", - n_cast=3, - reversed=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(suffix="crossattention.c_attn", - weight="weight", - bias="bias", - n_cast=2, - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(suffix="crossattention.q_attn", - weight="weight", - bias="bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col) - ] + def new_model_class(self): - @staticmethod - def attn_out() -> List: - return [ - Row_Layer(suffix="attn.c_proj", - weight="weight", - bias="bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row), - Row_Layer(suffix="crossattention.c_proj", - weight="weight", - bias="bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Row) - ] + return self.model - @staticmethod - def mlp_in() -> List: - return [ - Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True, - replace_layer=col_nn.Linear1D_Col), - ] - - @staticmethod - def mlp_out() -> List: - return [ - Row_Layer(suffix="mlp.c_proj", - weight="weight", - bias="bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row) - ] - - @staticmethod - def embedding() -> List: - return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)] + def postprocess(self): + return self.model -from transformers import GPT2LMHeadModel +# GPT2Model +class GPT2ModelPolicy(GPT2Policy): - -class GPT2LMHeadModelPolicy(GPT2Policy): - - @staticmethod - def argument_policy(config, world_size): - base_argument = GPT2Policy.argument_policy(config, world_size) - argument = { - GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[ - GPT2LMHeadModelPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def unembedding() -> List: - return [ - Col_Layer(suffix="lm_head", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True) - ] + def 
__init__(self) -> None: + super().__init__() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 54fea0335..043ed1a74 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -108,7 +108,7 @@ def check_bert(rank, world_size, port): backward_lsit = [BertForMaskedLM, BertLMHeadModel] for model_fn in forward_list: - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(world_size, model_fn) check_forward(org_model, sharded_model) if model_fn in backward_lsit: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py new file mode 100644 index 000000000..2f679b83f --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -0,0 +1,118 @@ +import copy +import os + +import pytest +import torch +from transformers import AutoTokenizer, GPT2Config, GPT2Model + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + +def build_model(world_size, model_fn): + config = GPT2Config() + config.attn_pdrop = 0 + config.embd_pdrop = 0 + config.resid_pdrop = 0 + config.summary_first_dropout + + org_model = model_fn(config=config) + org_model_forshard = copy.deepcopy(org_model) + + org_model.to('cuda') + # TODO: no need to transfer to cuda + org_model_forshard.to('cuda') + shard_config = ShardConfig(tensor_parallel_size=world_size,) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + + #orgin model + org_model.eval() + org_out = org_model(**tokenized_input) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**tokenized_input) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + # tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.h[0].attn.c_attn.weight.grad + + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.h[0].attn.c_attn.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model 
loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
+
+
+def check_gpt2(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    forward_list = [
+        GPT2Model,
+
+        # TODO: these do not work yet
+        # BertModel,
+        # BertForSequenceClassification
+        # BertForNextSentencePrediction,
+    ]
+    backward_list = []
+
+    for model_fn in forward_list:
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward(org_model, sharded_model)
+
+        if model_fn in backward_list:
+            check_backward(org_model, sharded_model)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_gpt2():
+    spawn(check_gpt2, 2)
+
+
+if __name__ == "__main__":
+    test_gpt2()
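
The least obvious step in the LinearConv1D_Col.from_native_module conversion above (and in its row counterpart) is the chunk reordering driven by n_cast. The standalone sketch below is an illustration only, not code from this diff: it uses made-up toy sizes with world_size = 2 and n_cast = 3, reproduces the same order[i::world_size] reordering without torch.distributed, and checks that every column shard ends up holding one matching Q/K/V slice, which is the layout the sharded fused attention expects.

# Illustrative sketch only: mimics the fused-QKV reordering used by LinearConv1D_Col,
# with made-up toy sizes and no process groups involved.
import torch

world_size, n_cast, hidden = 2, 3, 4      # n_cast = 3 because attn.c_attn fuses Q, K and V
# HuggingFace Conv1D stores its weight as (in_features, out_features),
# so the fused projection has columns laid out as [ Q | K | V ].
q = torch.full((hidden, hidden), 1.0)
k = torch.full((hidden, hidden), 2.0)
v = torch.full((hidden, hidden), 3.0)
fused = torch.cat([q, k, v], dim=1)       # shape (hidden, n_cast * hidden)

# Same reordering as from_native_module: interleave the world_size * n_cast chunks
# so that slice i of Q, K and V all land in the same contiguous region.
order = torch.arange(world_size * n_cast)
new_order = torch.cat([order[i::world_size] for i in range(world_size)])   # tensor([0, 2, 4, 1, 3, 5])
chunks = torch.chunk(fused, world_size * n_cast, dim=1)
rearranged = torch.cat([chunks[i] for i in new_order], dim=1)

# A plain column-wise shard of the rearranged weight now gives every "rank"
# one Q slice, one K slice and one V slice, in that order.
for rank, shard in enumerate(torch.chunk(rearranged, world_size, dim=1)):
    q_part, k_part, v_part = torch.chunk(shard, n_cast, dim=1)
    assert torch.all(q_part == 1.0)
    assert torch.all(k_part == 2.0)
    assert torch.all(v_part == 3.0)
    print(f"rank {rank}: shard {tuple(shard.shape)} holds one Q/K/V slice each")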