diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index 5234b6b1a..4a06bdcb7 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc class ParallelLayer(nn.Module): + global_state_dict: bool = True def __init__(self): diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e817ea3eb..208a391c3 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None - ctx.parallel_mode = parallel_mode + ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce output = torch.matmul(input_, weight.t()) @@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function): grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) - total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated _ = torch.empty(1, device=grad_output.device) + 1 @@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None -def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _split(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.dim, ctx.process_group), None, None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. 
+ """ + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +def _reduce(input_, process_group): + # skip if only one rank involved + if dist.get_world_size(process_group) == 1: + return input_ + else: + dist.all_reduce(input_, group=process_group) + return input_ + + +def _split(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = dist.get_rank(process_group) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def gather_forward_split_backward(input_, dim, process_group): + return _GatherForwardSplitBackward.apply(input_, dim, process_group) + + +def split_forward_gather_backward(input_, dim, process_group): + return _SplitForwardGatherBackward.apply(input_, dim, process_group) + + +def reduce_input(input_, process_group): + return _ReduceInput.apply(input_, process_group) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 0f653a9be..5d295be6b 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,58 +1,20 @@ -import os -from contextlib import contextmanager - import torch +import torch.distributed as dist import torch.nn as nn - -class SeedManager: - """ - This class is a random state manager to change random state for different random seed. 
- - """ - - def __init__(self): - original_state = torch.cuda.get_rng_state() - # TODO: unify this seed manager with the colossalai.context.random - seed = os.getpid() - torch.cuda.manual_seed(int(seed)) - self.dropout_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(original_state) - - def set_mode(self, rng_state): - torch.cuda.set_rng_state(rng_state) - - def get_current_mode(self): - current_state = torch.cuda.get_rng_state() - return current_state - - @contextmanager - def dropout_mode(self): - """ - This is a context manager to change the dropout state and recover the original state. - - Usage: - :: - >>> with _seed_manager.dropout_mode(): - >>> input = super().forward(input) - """ - try: - current_mode = self.get_current_mode() - yield self.set_mode(self.dropout_state) - finally: - self.dropout_state = self.get_current_mode() - self.set_mode(current_mode) - - -_seed_manager = SeedManager() +from .utils import create_randomizer_with_offset class Dropout1D(nn.Dropout): - def __init__(self, p=0.5, inplace=False): + def __init__(self, p=0.5, inplace=False, process_group=None): super().__init__(p, inplace) + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + def forward(self, input): - with _seed_manager.dropout_mode(): + with self.randomizer.fork_rng(): input = super().forward(input) return input diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index a9f3cf5ad..2ad6523c9 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -2,12 +2,16 @@ # -*- encoding: utf-8 -*- import math +from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Callable, Tuple +from typing import Callable, List, Tuple, Union import torch +import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from colossalai.communication import broadcast @@ -22,13 +26,11 @@ from colossalai.nn.layer.parallel_1d._utils import ( gather_forward_split_backward, get_parallel_input, reduce_grad, - reduce_input, set_parallel_input, - split_forward_gather_backward, ) from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from colossalai.registry import LAYERS +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise from colossalai.utils.checkpointing import ( broadcast_state_dict, gather_tensor_parallel_state_dict, @@ -36,7 +38,13 @@ from colossalai.utils.checkpointing import ( ) from colossalai.utils.cuda import get_current_device -from ._operation import linear_with_async_comm +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .utils import create_randomizer_with_offset Fast_LN = None try: @@ -46,17 +54,172 @@ except ImportError: pass -# @LAYERS.register_module -class Linear1D(ColossalaiModule): - r"""Linear layer for 1D parallelism. +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. 
+ + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + pass + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. Args: in_features (int): size of each input sample. out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): Whether to call all-gather on output, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. 
+ if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. 
+ skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -72,32 +235,149 @@ class Linear1D(ColossalaiModule): out_features: int, bias: bool = True, dtype: torch.dtype = None, - gather_output: bool = False, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - parallel_input = get_parallel_input() - if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, - bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: - layer = Linear1D_Row(in_features, - out_features, + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - super().__init__(layer) + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias -# @LAYERS.register_module class LayerNorm1D(ColossalaiModule): r""" Layer Normalization for colossalai @@ -152,7 +432,6 @@ class LayerNorm1D(ColossalaiModule): super()._save_to_state_dict(destination, prefix, keep_vars) -# @LAYERS.register_module class Classifier1D(ParallelLayer): r"""RowLinear with given weight. Classifier of 1D parallelism. 
@@ -288,7 +567,6 @@ class Classifier1D(ParallelLayer): return output -# @LAYERS.register_module class VocabParallelClassifier1D(ParallelLayer): r"""ColLinear with given weight. Classifier of 1D parallelism. @@ -424,317 +702,8 @@ class VocabParallelClassifier1D(ParallelLayer): # @LAYERS.register_module -class Linear1D_Col(ParallelLayer): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) - self.out_features_per_partition = out_features - - # Parameters. - # Initialize weight. 
- factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - is_parallel_output = not self.gather_output - set_parallel_input(is_parallel_output) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - # output_parallel = F.linear(input_parallel, self.weight, bias) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output -# @LAYERS.register_module -class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. 
- bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) - self.input_size_per_partition = in_features - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: 
False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -# @LAYERS.register_module class Embedding1D(ParallelLayer): r"""Embedding for 1D parallelism. @@ -842,7 +811,6 @@ class Embedding1D(ParallelLayer): return output -# @LAYERS.register_module class VocabParallelEmbedding1D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. @@ -960,7 +928,6 @@ class VocabParallelEmbedding1D(ParallelLayer): return output -# @LAYERS.register_module class Dropout1D(ParallelLayer): """Dropout layer of 1D parallelism. diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py new file mode 100644 index 000000000..c3d6ab57e --- /dev/null +++ b/colossalai/shardformer/layer/utils.py @@ -0,0 +1,138 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Randomizer: + """ + Randomizer enables the program to be executed under a different seed within the context. + + Example: + + ```python + randomizer = Randomizer(seed=1024) + + with randomizer.fork(): + # do something here with seed 1024 + do_something() + ``` + + Args: + seed (int): The random seed to set. 
+
+    Note: pass ``enable_cpu=True`` to :meth:`fork_rng` to fork the CPU RNG state as well.
+    """
+
+    _INDEX = 0
+
+    def __init__(self, seed: int):
+        # TODO: remove colossalai.context.random
+
+        self.seed = seed
+
+        # Handle CUDA rng state
+        # 1. get the current rng state
+        # 2. set the seed and store the rng state
+        # 3. recover the original rng state
+        cuda_original_rng_state = torch.cuda.get_rng_state()
+        torch.cuda.manual_seed(seed)
+        self.cuda_rng_state = torch.cuda.get_rng_state()
+        torch.cuda.set_rng_state(cuda_original_rng_state)
+
+        # do the same for the CPU rng state
+        cpu_original_rng_state = torch.get_rng_state()
+        torch.manual_seed(seed)
+        self.cpu_rng_state = torch.get_rng_state()
+        torch.set_rng_state(cpu_original_rng_state)
+
+    def _set_cuda_rng_state(self, rng_state):
+        torch.cuda.set_rng_state(rng_state)
+
+    def _get_cuda_rng_state(self):
+        current_state = torch.cuda.get_rng_state()
+        return current_state
+
+    def _set_cpu_rng_state(self, rng_state):
+        torch.set_rng_state(rng_state)
+
+    def _get_cpu_rng_state(self):
+        current_state = torch.get_rng_state()
+        return current_state
+
+    @contextmanager
+    def fork_rng(self, enable_cpu: bool = False):
+        """
+        This is a context manager to switch to the randomizer's RNG state and recover the original state on exit.
+
+        Usage:
+        ::
+            >>> with randomizer.fork_rng():
+            >>>     input = super().forward(input)
+        """
+        try:
+            current_cuda_rng_state = self._get_cuda_rng_state()
+            self._set_cuda_rng_state(self.cuda_rng_state)
+
+            if enable_cpu:
+                current_cpu_rng_state = self._get_cpu_rng_state()
+                self._set_cpu_rng_state(self.cpu_rng_state)
+            yield
+        finally:
+            self.cuda_rng_state = self._get_cuda_rng_state()
+            self._set_cuda_rng_state(current_cuda_rng_state)
+
+            if enable_cpu:
+                self.cpu_rng_state = self._get_cpu_rng_state()
+                self._set_cpu_rng_state(current_cpu_rng_state)
+
+    @staticmethod
+    def index():
+        """
+        Return the index of the randomizer. The index is useful when the user wants
+        to introduce some randomness in the program.
+
+        Note:
+            The index will increment by one each time this method is called.
+
+        Example:
+
+        ```python
+        # assume we need a randomizer to initialize the weights of different layers
+        # we can use the index of the randomizer so that
+        # each layer has its own randomizer with a different seed
+        base_seed = torch.random.initial_seed()
+        seed = base_seed + Randomizer.index()
+        randomizer = Randomizer(seed)
+
+        with randomizer.fork_rng():
+            init_weights()
+        ```
+
+        """
+        idx = Randomizer._INDEX
+        Randomizer._INDEX += 1
+        return idx
+
+
+def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
+    """
+    Create a randomizer with an offset. The offset is the sum of the randomizer index and the process rank.
+
+    Args:
+        seed (int): The base random seed to set.
+        process_group (ProcessGroup): the process group to get the rank from.
+            Defaults to None, which means the global process group.
+
+    Returns:
+        Randomizer: the randomizer with offset.
+ """ + offset = Randomizer.index() + + if dist.is_initialized(): + rank = dist.get_rank(process_group) + offset += rank + + seed += offset + return Randomizer(seed=seed) diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py new file mode 100644 index 000000000..afb1fc003 --- /dev/null +++ b/colossalai/tensor/d_tensor/api.py @@ -0,0 +1,44 @@ +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.device.device_mesh import DeviceMesh + +from .d_tensor import DTensor +from .sharding_spec import ShardingSpec + + +def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: + """ + Shard the first dim of the given tensor + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) + return DTensor(tensor, device_mesh, sharding_spec) + + +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: + """ + Shard the first dim of the given tensor + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) + return DTensor(tensor, device_mesh, sharding_spec) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index ee7ef74a9..f15956ea3 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -34,7 +34,7 @@ class Layout: def get_sharded_shape_per_device(self): sharded_shape = list(self.entire_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' 
@@ -45,14 +45,15 @@ class Layout: sharding_spec = self.sharding_spec # make sure all axes in logical device mesh only be used once - dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) - for dim, shard_list in sharding_spec.dim_partition_dict.items(): - for element in shard_list: - if element in dim_check_list: - dim_check_list.remove(element) - else: - raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + if self.device_mesh.logical_mesh_id is not None: + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): @@ -60,7 +61,7 @@ class Layout: num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index cf02aac30..abc70e19a 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta): process_groups_dict = source_layout.device_mesh.process_groups_dict # legal sharding dims means the mesh_id is still available to use. 
- legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py new file mode 100644 index 000000000..449522c64 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -0,0 +1,67 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_linear_1d_col(): + linear = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + + assert linear_col.weight.shape == torch.Size([64, 32]) + assert linear_col.bias.shape == torch.Size([64]) + + # check computation correctness + x = torch.rand(4, 32).cuda() + out = linear(x) + gather_out = linear_col(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_col.weight.grad) + + +def check_linear_1d_row(): + linear = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear_row.weight.shape == torch.Size([128, 16]) + assert linear_row.bias.shape == torch.Size([128]) + + # check computation correctness + x = torch.rand(4, 32).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_linear_1d_col() + check_linear_1d_row() + + +@rerun_if_address_is_in_use() +def test_linear(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_linear()
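The two new layers are designed to compose Megatron-style: a `Linear1D_Col` with `gather_output=False` feeds a `Linear1D_Row` with `parallel_input=True`, so the only forward-pass communication is the final all-reduce inside `reduce_input`. The sketch below mirrors the structure of `test_linear_1d.py` above; it is illustrative only, assumes 2 GPUs, and the helper names (`check_col_row_pair`, `run_dist`, `test_col_row_pair`) are made up for the example.

```python
import torch
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_col_row_pair():
    # seed so every rank builds identical full weights and an identical input;
    # the row layer all-reduces partial results, so inputs must match across ranks
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    fc1 = nn.Linear(32, 128).cuda()
    fc2 = nn.Linear(128, 32).cuda()

    # column-parallel layer keeps its partial output, row-parallel layer
    # consumes that partial output and all-reduces the final result
    col = Linear1D_Col.from_native_module(fc1, process_group=None, gather_output=False)
    row = Linear1D_Row.from_native_module(fc2, process_group=None, parallel_input=True)

    x = torch.rand(4, 32).cuda()
    expected = fc2(fc1(x))
    sharded = row(col(x))

    # the all-reduce changes the accumulation order, so allow a small tolerance
    assert_close(expected, sharded, rtol=1e-5, atol=1e-5)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_col_row_pair()


@rerun_if_address_is_in_use()
def test_col_row_pair():
    spawn(run_dist, nprocs=2)


if __name__ == '__main__':
    test_col_row_pair()
```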
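The `Randomizer` in the new `utils.py` replaces the old `SeedManager`: it snapshots a seeded RNG state at construction, and `fork_rng()` temporarily swaps that state in, carrying the forked stream forward across calls while leaving the ambient RNG untouched. A minimal single-process sketch of that contract (assuming a CUDA device is available; `demo_fork_rng` is a made-up name, and `create_randomizer_with_offset` only adds the rank offset when `torch.distributed` is initialized):

```python
import torch

from colossalai.shardformer.layer.utils import Randomizer, create_randomizer_with_offset


def demo_fork_rng():
    randomizer = Randomizer(seed=1024)

    # draws inside fork_rng() come from the randomizer's private RNG stream
    with randomizer.fork_rng():
        a = torch.rand(3, device='cuda')

    # the ambient RNG state outside the context is restored on exit
    before = torch.cuda.get_rng_state()
    with randomizer.fork_rng():
        b = torch.rand(3, device='cuda')
    after = torch.cuda.get_rng_state()
    assert torch.equal(before, after)

    # the private stream advances across forks, so the two draws differ
    assert not torch.equal(a, b)

    # with torch.distributed initialized, each rank in a tensor-parallel group
    # would get a different offset (index + rank), hence Dropout1D produces
    # different masks per rank; here only the index offset applies
    offset_randomizer = create_randomizer_with_offset(seed=1024, process_group=None)
    with offset_randomizer.fork_rng():
        print(torch.rand(3, device='cuda'))


if __name__ == '__main__':
    demo_fork_rng()
```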