From 424629fea023a83aa84eacf55afc8007314d9f54 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:41:20 +0800 Subject: [PATCH 01/33] [shardformer/sequence parallel] Cherry pick commit to new branch (#4450) * [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish code --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/shardformer/layer/_operation.py | 276 +++++++++++++++++- colossalai/shardformer/layer/linear.py | 23 +- .../shardformer/layer/qkv_fused_linear.py | 27 +- colossalai/shardformer/modeling/gpt2_seq.py | 222 ++++++++++++++ .../shardformer/policies/base_policy.py | 26 +- colossalai/shardformer/policies/gpt2.py | 9 + colossalai/shardformer/shard/shard_config.py | 1 + .../test_gpt2_qkv_fused_linear_1d.py | 34 ++- .../test_layer/test_linear_1d.py | 75 +++-- tests/test_shardformer/test_model/_utils.py | 15 +- .../test_model/test_shard_gpt2.py | 7 + 12 files changed, 655 insertions(+), 65 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 28a19af0c..3d45a9112 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_fused_normalization: bool = False, enable_flash_attention: bool = False, enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, @@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase): enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused) + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01..13e563123 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch.distributed as dist import torch.nn.functional as F @@ -141,6 +143,215 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + # create new stream for calculate the gradient + calculate_stream = torch.cuda.Stream() + + # do all gather in default stream + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + + # calculate gradient in calculate_stream + with torch.cuda.stream(calculate_stream): + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + + torch.cuda.current_stream().wait_stream(calculate_stream) + + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + with torch.cuda.stream(calculate_stream): + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + print(grad_output.shape, input_parallel.shape) + grad_weight = grad_output.t().matmul(input_parallel) + + torch.cuda.current_stream().wait_stream(calculate_stream) + + return output, grad_weight, grad_bias, None, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + return output, grad_weight, grad_bias, None, None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +411,26 @@ class _ReduceBackward(torch.autograd.Function): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather + input_ = input_.contiguous() rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ @@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(input_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim, overlap) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d59b68ce4..69ac3ad25 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule): gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1, self.overlap) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule): output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index df942d43e..ccb2bf7ea 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -173,6 +176,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -185,6 +189,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -296,15 +301,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - # input_parallel = input_ # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -329,6 +338,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -346,6 +356,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -363,6 +374,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -499,7 +511,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py new file mode 100644 index 000000000..a6da96e7b --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2_seq.py @@ -0,0 +1,222 @@ +# this code is modified from transformers.models.gpt2.modeling_gpt2 +# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 + +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import logging + +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +# TODO: put all contents in `gpt2.py` and make it compatible with pipeline +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 69493bfb6..7022a1cfd 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,17 +11,12 @@ from torch.nn import Module from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] -class ParallelModule(): - - def __init__(self): - pass - - @dataclass class SubModuleReplacementDescription: r""" @@ -231,3 +226,22 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] + + def append_seq_parallel_to_policy( + self, + suffix_list: List[str], + module_policy_description: ModulePolicyDescription, + ): + r""" + Append the sequence parallel policy to the policy for the given key. + + Args: + suffix_list (List[str]): the suffix list of the module to be parallelized + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + """ + + for sub_description in module_policy_description.sub_module_replacement: + if (sub_description.suffix in suffix_list): + if sub_description.kwargs is None: + sub_description.kwargs = {} + sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 20e5fa372..276d95660 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -7,6 +7,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward +from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -49,6 +50,9 @@ class GPT2Policy(Policy): target_module=col_nn.DropoutForParallelInput, ), ]) + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -120,6 +124,11 @@ class GPT2Policy(Policy): policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ 'forward': get_gpt2_flash_attention_forward(), }) + + if self.shard_config.enable_sequence_parallelism: + suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] + self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) + return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0c28f115d..a36e878c6 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -28,6 +28,7 @@ class ShardConfig: enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False + enable_sequence_parallelism: bool = False # pipeline_parallel_size: int # data_parallel_size: int diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index b45cd172c..ae6a1dc90 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool): linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, + seq_parallel=seq_parallel, n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) @@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool): linear.load_state_dict(linear_conv_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) # check backward correctness out.sum().backward() @@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool): assert_close(target_grad, linear_row.weight.grad) +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_gpt2_qkv_fused_linear_1d() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index aa75879e0..3ad8f14b9 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) -def check_linear_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear_copy, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel, + overlap=overlap) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool): linear_row.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_col_plus_row(lazy_init: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool): with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1_copy, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel, + overlap=overlap) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool): linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -def run_dist(rank, world_size, port): +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +@parameterize('overlap', [False, True]) +def run_dist_linear_test(lazy_init, seq_parallel, overlap): + check_linear_1d_col(lazy_init, seq_parallel, overlap) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel, overlap) + + +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 921af2a8b..7e1e6f2fe 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,4 +1,5 @@ import copy +import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -25,6 +26,7 @@ def build_model(model_fn, enable_tensor_parallelism=True, enable_flash_attention=False, enable_jit_fused=False, + enable_sequence_parallelism=False, use_lazy_init: bool = False): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -38,7 +40,8 @@ def build_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) @@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo return loss data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data['input_ids'].shape[1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data['input_ids'].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(1, times) + sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ca086bf12..c97702cbb 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': False, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', }]) @clear_cache_before_run() def run_gpt2_test(test_config): From 6ef33f75aa05390894e411296acf8db8a0b55118 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 16 Aug 2023 16:11:57 +0800 Subject: [PATCH 02/33] [shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446) * support DDP for HybridPlugin/add tp+dp tests * add docstring for HybridParallelPlugin --- .../booster/plugin/hybrid_parallel_plugin.py | 129 ++++++++++++++---- tests/test_shardformer/test_model/_utils.py | 13 ++ .../test_model/test_shard_bert.py | 17 ++- .../test_model/test_shard_bloom.py | 19 +-- .../test_model/test_shard_chatglm.py | 19 +-- .../test_model/test_shard_gpt2.py | 21 ++- .../test_model/test_shard_llama.py | 20 +-- .../test_model/test_shard_opt.py | 20 +-- .../test_model/test_shard_t5.py | 20 +-- .../test_model/test_shard_vit.py | 21 +-- 10 files changed, 199 insertions(+), 100 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3d45a9112..00c714fe4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -6,7 +6,8 @@ import numpy as np import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.nn import Module +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -28,7 +29,8 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 class HybridParallelModule(ModelWrapper): - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None: + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, + ddp_config: dict) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) @@ -45,7 +47,15 @@ class HybridParallelModule(ModelWrapper): module = module.to(dtype=torch.bfloat16).cuda() else: module = module.cuda() # train without AMP - # TODO(ver217): support TP+DP + + if use_ddp: + + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + super().__init__(module) def sync_shared_params(self): @@ -68,6 +78,12 @@ class HybridParallelModule(ModelWrapper): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) @@ -140,29 +156,81 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False. + """ + + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + cpu_offload: bool = False, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + num_microbatches: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers=True, + bucket_cap_mb=25, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False, + static_graph=False) -> None: - def __init__( - self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - cpu_offload: bool = False, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - num_microbatches: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - ) -> None: super().__init__() assert dist.get_world_size() % ( tp_size * pp_size @@ -208,6 +276,13 @@ class HybridParallelPlugin(PipelinePluginBase): min_scale=min_scale, max_scale=max_scale, ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) self.max_norm = max_norm @property @@ -241,7 +316,9 @@ class HybridParallelPlugin(PipelinePluginBase): lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 7e1e6f2fe..789b3b24e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -13,6 +13,7 @@ from torch.optim import Adam, Optimizer from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -259,3 +260,15 @@ def check_grad(org_model: Module, assert torch.allclose( org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + + +def unwrap_model(module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None): + if isinstance(module, HybridParallelModule): + module = module.unwrap() + if base_model_class_name is None: + return module + if module.__class__.__name__ == base_model_class_name: + return module + return getattr(module, base_model_attribute_name, None) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0a24e46d2..49de9cc03 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -15,6 +15,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -44,13 +45,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'BertModel': - bert = org_model - sharded_bert = sharded_model.unwrap() - else: - bert = org_model.bert - sharded_bert = sharded_model.unwrap().bert + + bert = unwrap_model(org_model, 'BertModel', 'bert') + sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] @@ -98,6 +95,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index ed0d1d8e4..af014a858 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -13,6 +13,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -46,12 +47,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'BloomModel': - bloom = org_model - sharded_bloom = sharded_model.unwrap() - else: - bloom = org_model.transformer - sharded_bloom = sharded_model.unwrap().transformer + bloom = unwrap_model(org_model, 'BloomModel', 'transformer') + sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] @@ -97,12 +94,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_bloom_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index bb7775904..210f775b5 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ChatGLMModel': - chatglm_model = org_model - shard_chatglm_model = sharded_model.unwrap() - else: - chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.unwrap().transformer + chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') + shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] @@ -121,12 +118,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_chatglm_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c97702cbb..97295f72f 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,7 +3,6 @@ import torch from torch import distributed as dist import colossalai -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -15,6 +14,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - def unwrap(module): - if isinstance(module, HybridParallelModule): - module = module.unwrap() - if module.__class__.__name__ == 'GPT2Model': - return module - return module.transformer - # unwrap model - gpt2 = unwrap(org_model) - sharded_gpt2 = unwrap(sharded_model) + gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') + sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') col_layer_for_check = ['h[0].mlp.c_fc'] row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] @@ -106,6 +99,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, @@ -117,8 +116,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 30ebdfbe5..a433567b3 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -52,12 +53,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'LlamaModel': - llama_model = org_model - shard_llama_model = sharded_model.unwrap() - else: - llama_model = org_model.model - shard_llama_model = sharded_model.unwrap().model + llama_model = unwrap_model(org_model, 'LlamaModel', 'model') + shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] @@ -128,13 +125,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_llama_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8d1154d82..2fb14903b 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -51,12 +52,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'OPTModel': - opt_model = org_model - shard_opt_model = sharded_model.unwrap() - else: - opt_model = org_model.model - shard_opt_model = sharded_model.unwrap().model + opt_model = unwrap_model(org_model, 'OPTModel', 'model') + shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' @@ -123,14 +120,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_opt_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 066f7ee81..234ce812a 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,5 +1,6 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -14,6 +15,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - t5 = org_model - sharded_t5 = sharded_model.unwrap() + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] @@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) @clear_cache_before_run() def run_t5_test(test_config): - # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO(baizhou): add test_config for flash attention & jit operator after supporting - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 18df8ef55..b9d303841 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import ( check_output_hidden_state, check_weight, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model - if org_model.__class__.__name__ == 'ViTModel': - vit_model = org_model - shard_vit_model = sharded_model.unwrap() - else: - vit_model = org_model.vit - shard_vit_model = sharded_model.unwrap().vit + vit_model = unwrap_model(org_model, 'ViTModel', 'vit') + shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] @@ -120,15 +117,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' }]) def run_vit_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) From 26e29d58f0525ff573d6a2eeae328e0a4d7f9a68 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 16 Aug 2023 18:56:52 +0800 Subject: [PATCH 03/33] [devops] add large-scale distributed test marker (#4452) * [test] remove cpu marker * [test] remove gpu marker * [test] update pytest markers * [ci] update unit test ci --- .github/workflows/build_on_pr.yml | 2 +- .../compatiblity_test_on_dispatch.yml | 2 +- .github/workflows/compatiblity_test_on_pr.yml | 2 +- .../compatiblity_test_on_schedule.yml | 2 +- applications/Chat/tests/test_dataset.py | 79 ++++++------- applications/Chat/tests/test_models.py | 105 +++++++----------- pytest.ini | 6 +- tests/test_config/test_load_config.py | 1 - tests/test_context/test_hybrid_parallel.py | 1 - tests/test_data/test_cifar10_dataset.py | 3 +- tests/test_data/test_data_parallel_sampler.py | 1 - .../test_deterministic_dataloader.py | 1 - .../test_activation_checkpointing.py | 1 - 13 files changed, 81 insertions(+), 125 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 8a1bc8e11..4c7e08e57 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -208,7 +208,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 1778d64ee..63c0fbbb9 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -44,7 +44,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index c0f45c65a..c9f84806b 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -35,7 +35,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 15ac4f1a9..3f8fc9639 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index 64ea1178c..1d9aa50e2 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -14,29 +14,43 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer SFT_DATASET = [ { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 + "instruction": + "Provide a list of the top 10 most popular mobile games in Asia", + "input": + "", + "output": + "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": + 0 }, { - "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": "", - "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": 1 + "instruction": + "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": + "", + "output": + "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": + 1 }, { - "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": "", - "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": 2 + "instruction": + "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": + "", + "output": + "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": + 2 }, ] PROMPT_DATASET = [ { - "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": 0 + "instruction": + "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": + 0 }, { "instruction": "Write a descriptive paragraph about a memorable vacation you went on", @@ -71,9 +85,7 @@ def make_tokenizer(model: str): return tokenizer -def check_content(input_ids_stripped: torch.Tensor, - tokenizer: PreTrainedTokenizer, - model: str): +def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str): if model == "opt": # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. assert input_ids_stripped[0] == tokenizer.eos_token_id @@ -90,13 +102,10 @@ def check_content(input_ids_stripped: torch.Tensor, assert input_ids_stripped != tokenizer.mask_token_id -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_datasets_size", [2]) -def test_prompt_dataset(model: str, - max_datasets_size: int, - max_length: int): +def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): with tempfile.TemporaryDirectory() as tmp_dir: dataset_name = "prompt_dataset.json" with open(os.path.join(tmp_dir, dataset_name), "w") as f: @@ -119,19 +128,12 @@ def test_prompt_dataset(model: str, check_content(input_ids.masked_select(attention_mask), tokenizer, model) -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize(["dataset_path", "subset"], [ - ("Anthropic/hh-rlhf", "harmless-base"), - ("Dahoas/rm-static", None) -]) +@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), + ("Dahoas/rm-static", None)]) @pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_reward_dataset(model: str, - dataset_path: str, - subset: Optional[str], - max_datasets_size: int, - max_length: int): +def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): data = load_dataset(dataset_path, data_dir=subset) assert max_datasets_size <= len(data["train"]) \ and max_datasets_size <= len(data["test"]) @@ -188,15 +190,11 @@ def test_reward_dataset(model: str, assert torch.all(r_mask) -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("max_dataset_size", [2]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_sft_dataset(model: str, - dataset_path: Optional[str], - max_dataset_size: int, - max_length: int): +def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int): tokenizer = make_tokenizer(model) if dataset_path == "yizhongw/self_instruct": data = load_dataset(dataset_path, "super_natural_instructions") @@ -232,10 +230,7 @@ def test_sft_dataset(model: str, if __name__ == "__main__": - test_sft_dataset(model="bloom", - dataset_path="yizhongw/self_instruct", - max_dataset_size=2, - max_length=256) + test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) test_reward_dataset(model="gpt2", dataset_path="Anthropic/hh-rlhf", @@ -243,6 +238,4 @@ if __name__ == "__main__": max_datasets_size=8, max_length=256) - test_prompt_dataset(model="opt", - max_datasets_size=2, - max_length=128) + test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index bd6b3e8a5..e96ff8bd7 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -15,16 +15,17 @@ from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean -@pytest.mark.gpu @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [32]) -@pytest.mark.parametrize("actor_maker", [ - lambda: BLOOMActor(), - lambda: GPTActor(), +@pytest.mark.parametrize( + "actor_maker", + [ + lambda: BLOOMActor(), + lambda: GPTActor(), # HACK: skip llama due to long execution time # lambda: LlamaActor(), - lambda: OPTActor() -]) + lambda: OPTActor() + ]) @pytest.mark.parametrize("generate_kwargs", [{ "max_length": 64, "use_cache": True, @@ -32,23 +33,15 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mea "temperature": 1.0, "top_k": 50, }]) -def test_generation(actor_maker: Callable[[], Actor], - batch_size: int, - seq_len: int, - generate_kwargs: Dict[str, Any] - ): +def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() sequences = generate(actor.cuda(), input_ids, **generate_kwargs) assert sequences.shape == (batch_size, generate_kwargs["max_length"]) -@pytest.mark.cpu def test_utils(): - fn_input = { - "tensor": torch.ones((10, )), - "mask": torch.randint(0, 2, (10, )) - } + fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))} fn_output = masked_mean(dim=0, **fn_input) assert fn_output.dim() == 0 assert torch.allclose(fn_output, torch.tensor(1.0)) @@ -56,14 +49,14 @@ def test_utils(): batch_size = 4 num_labels = 10 fn_input = { - "r": torch.ones((batch_size, )), + "r": torch.ones((batch_size,)), "kl_coef": 1.0, "log_probs": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)), "action_mask": torch.randint(0, 2, (batch_size, num_labels)) } fn_output = compute_reward(**fn_input) - assert fn_output.shape == (batch_size, ) + assert fn_output.shape == (batch_size,) batch_size = 4 seq_len = 32 @@ -80,17 +73,11 @@ def test_utils(): assert fn_output.shape == (batch_size, num_actions) -@pytest.mark.cpu @pytest.mark.parametrize("lora_rank", [4]) @pytest.mark.parametrize("num_dim", [32]) @pytest.mark.parametrize("num_layers", [4]) -def test_lora(lora_rank: int, - num_dim: int, - num_layers: int): - model = nn.ModuleList( - [nn.Linear(num_dim, num_dim) - for _ in range(num_layers)] - ) +def test_lora(lora_rank: int, num_dim: int, num_layers: int): + model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)]) lora_model = convert_to_lora_module(model, lora_rank) assert isinstance(lora_model, nn.ModuleList) for i in range(num_layers): @@ -103,8 +90,7 @@ def test_lora(lora_rank: int, assert isinstance(lora_model[i], LoraLinear) assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, - lora_model[i].lora_B @ lora_model[i].lora_A) + assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A) optimizer = torch.optim.Adam(lora_model.parameters()) x = torch.randn(8, num_dim) for i in range(num_layers): @@ -120,20 +106,19 @@ def test_lora(lora_rank: int, lora_model[i].lora_B @ lora_model[i].lora_A) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [128]) -@pytest.mark.parametrize("models_maker", [ - lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), - lambda: (GPTActor(), GPTCritic(), GPTRM()), +@pytest.mark.parametrize( + "models_maker", + [ + lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), + lambda: (GPTActor(), GPTCritic(), GPTRM()), # HACK: skip llama due to long execution time # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), - lambda: (OPTActor(), OPTCritic(), OPTRM()), -]) + lambda: (OPTActor(), OPTCritic(), OPTRM()), + ]) @torch.no_grad() -def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], - batch_size: int, - seq_len: int): +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): actor_input = { "input_ids": torch.randint(0, 100, (batch_size, seq_len)), @@ -162,17 +147,14 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], rm_output = rm(**rm_input) assert actor_output.logits.shape[:2] == (batch_size, seq_len) - assert critic_output.shape == (batch_size, ) - assert rm_output.shape == (batch_size, ) + assert critic_output.shape == (batch_size,) + assert rm_output.shape == (batch_size,) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("seq_len", [128]) @pytest.mark.parametrize("num_labels", [100]) -def test_loss(batch_size: int, - seq_len: int, - num_labels: int): +def test_loss(batch_size: int, seq_len: int, num_labels: int): loss = GPTLMLoss() loss_input = { "logits": torch.randn(batch_size, seq_len, num_labels), @@ -182,54 +164,43 @@ def test_loss(batch_size: int, loss = PolicyLoss() loss_input = { - "log_probs": torch.randn(batch_size, ), - "old_log_probs": torch.randn(batch_size, ), - "advantages": torch.randn(batch_size, ) + "log_probs": torch.randn(batch_size,), + "old_log_probs": torch.randn(batch_size,), + "advantages": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = ValueLoss() loss_input = { - "values": torch.randn(batch_size, ), - "old_values": torch.randn(batch_size, ), - "reward": torch.randn(batch_size, ) + "values": torch.randn(batch_size,), + "old_values": torch.randn(batch_size,), + "reward": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = LogSigLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) loss = LogExpLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) if __name__ == "__main__": - generate_kwargs = dict(max_length=40, - use_cache=True, - do_sample=True, - temperature=1.0, - top_k=50) - test_generation(lambda: LlamaActor(), - batch_size=4, - seq_len=32, - generate_kwargs=generate_kwargs) + generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50) + test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs) test_utils() test_lora(lora_rank=2, num_dim=8, num_layers=2) - test_models(models_maker=lambda: (BLOOMActor(), - BLOOMCritic(), - BLOOMRM()), - batch_size=8, - seq_len=128) + test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/pytest.ini b/pytest.ini index e8a60c853..7912dbffc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,5 @@ [pytest] markers = - cpu: tests which can run on CPU - gpu: tests which requires a single GPU - dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features + dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) + largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 550af2a4a..38b5e3f5f 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -8,7 +8,6 @@ import pytest from colossalai.context.config import Config -@pytest.mark.cpu def test_load_config(): filename = Path(__file__).parent.joinpath('sample_config.py') config = Config.from_file(filename) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index 9f26a5af5..d25668afd 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -143,7 +143,6 @@ def run_dist(rank, world_size, port, backend, port_list, host): reset_seeds() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_context(): """ diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py index 4b9ca61d9..dfa9fa211 100644 --- a/tests/test_data/test_cifar10_dataset.py +++ b/tests/test_data/test_cifar10_dataset.py @@ -5,11 +5,10 @@ import os from pathlib import Path import pytest -from torchvision import transforms, datasets from torch.utils.data import DataLoader +from torchvision import datasets, transforms -@pytest.mark.cpu def test_cifar10_dataset(): # build transform transform_pipeline = [transforms.ToTensor()] diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 2ad3fd696..7beef707c 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -53,7 +53,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 239e79dff..283b5cc35 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ b/tests/test_data/test_deterministic_dataloader.py @@ -64,7 +64,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 2930552cc..b7764c2f4 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -40,7 +40,6 @@ def forward_inplace(x, weight): return out -@pytest.mark.gpu @clear_cache_before_run() @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False]) From a78daf6180cec55b37713418cad8f406f57939e8 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Wed, 16 Aug 2023 19:29:03 +0800 Subject: [PATCH 04/33] [shardformer] support interleaved pipeline (#4448) * support interleaved pipeline * fix unit test * remove virtual stage test in stage mgr * add droped type hint and updated bwd --- colossalai/cluster/process_group_mesh.py | 10 +- colossalai/pipeline/p2p.py | 45 +-- .../pipeline/schedule/interleaved_pp.py | 370 ++++++++++++++++++ colossalai/pipeline/schedule/one_f_one_b.py | 78 +++- colossalai/pipeline/stage_manager.py | 78 +--- .../test_schedule/test_interleaved.py | 161 ++++++++ tests/test_pipeline/test_stage_manager.py | 9 - 7 files changed, 642 insertions(+), 109 deletions(-) create mode 100644 colossalai/pipeline/schedule/interleaved_pp.py create mode 100644 tests/test_pipeline/test_schedule/test_interleaved.py diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1dfd261d5..623160003 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,17 +94,23 @@ class ProcessGroupMesh: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: """Convert a coordinate to a rank. + mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. + with wrap, index out of range would be wrapped around. + For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2) Args: coords (Tuple[int, ...]): Coordinate to be converted. shape (Tuple[int, ...]): Shape of the process group mesh. + mode (Optional[str]): The mode for numpy.ravel_multi_index. Returns: int: Rank of the coordinate. """ - return np.ravel_multi_index(coord, shape) + + assert mode in ["raise", "wrap", "clip"] + return np.ravel_multi_index(coord, shape, mode) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. It the process group doesn't exist, it will be created. diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index af7a00b5c..aed85cf91 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -173,14 +173,10 @@ class PipelineP2PCommunication: Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) return input_tensor @@ -193,14 +189,11 @@ class PipelineP2PCommunication: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object(next_rank, cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) return output_tensor_grad @@ -211,12 +204,10 @@ class PipelineP2PCommunication: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_last_stage(): - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -225,9 +216,7 @@ class PipelineP2PCommunication: input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.stage_manager.is_first_stage(): - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(input_object, cur_rank, prev_rank, - self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py new file mode 100644 index 000000000..35a33491b --- /dev/null +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -0,0 +1,370 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from .base import PipelineSchedule + + +class InterleavedSchedule(PipelineSchedule): + + def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: + self.num_model_chunks = num_model_chunks + assert num_microbatches % self.num_model_chunks == 0, \ + "Number of microbatches should be an integer multiple of number of model chunks" + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.microbatch_size: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not forward: + model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def is_first_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the first stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the first stage. + """ + if self.stage_manager.is_first_stage() and model_chunk_id == 0: + return True + return False + + def is_last_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the last stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the last stage. + """ + if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + return True + return False + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.is_first_stage(model_chunk_id): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.is_last_stage(model_chunk_id): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.is_last_stage(model_chunk_id): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.is_first_stage(model_chunk_id): + self.comm.send_backward(input_object, prev_rank) + + def forward_step(self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model (Module): Model Chunk to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + + if self.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step(self, + model_chunk: Module, + optimizer: OptimizerWrapper, + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False) -> dict: + """Runs interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model_chunk (List[Module]): Model Chunk to be trained. + optimizer (OptimizerWrapper): Optimizer to be used. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + + self.load_batch(data_iter) + num_model_chunks = len(model_chunk) + + # num_warmup_microbatches is the step when not all the processes are working + num_microbatches = self.num_microbatches * num_model_chunks + if forward_only: + num_warmup_microbatches = num_microbatches + else: + num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [[] for _ in range(num_model_chunks)] + output_objs = [[] for _ in range(num_model_chunks)] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # for ranks except the first one, get into recv state + # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) + input_obj = self.recv_forward(0) + input_objs[0].append(input_obj) + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=True) + + # recv first on first rank to avoid sending or recving at the same time + if self.stage_manager.is_first_stage(): + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + self.send_forward(model_chunk_id, output_obj) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + else: + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not forward_only: + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) + if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: + break + else: + model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) + + input_obj = self.recv_forward(model_chunk_id) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.recv_forward(model_chunk_id) + + else: + self.send_forward(model_chunk_id, output_obj) + # Add input_obj and output_obj to end of list. + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + + model_chunk_id = self.get_model_chunk_id(i, forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + # backward + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_microbatches_remaining, num_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=False) + # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ade3cf456..f5e4929aa 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -53,6 +53,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.microbatch_offset += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For 1F1B. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For 1F1B. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank) + def forward_step(self, model: Module, input_obj: Optional[dict], @@ -171,11 +227,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # Run warmup forward passes. for i in range(num_warmup_microbatches): - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not forward_only: input_objs.append(input_obj) @@ -185,7 +241,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() # Run 1F1B in steady state. for i in range(num_microbatches_remaining): @@ -193,15 +249,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not last_iteration: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() else: # TODO adjust here - self.comm.send_forward(output_obj) - output_obj_grad = self.comm.recv_backward() + self.send_forward(output_obj) + output_obj_grad = self.recv_backward() # Add input_obj and output_obj to end of list. input_objs.append(input_obj) @@ -216,8 +272,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if last_iteration: input_obj = None else: - input_obj = self.comm.recv_forward() - self.comm.send_backward(input_obj_grad) + input_obj = self.recv_forward() + self.send_backward(input_obj_grad) # Run cooldown backward passes. if not forward_only: @@ -225,9 +281,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = self.comm.recv_backward() + output_obj_grad = self.recv_backward() input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.comm.send_backward(input_obj_grad) + self.send_backward(input_obj_grad) if outputs is not None: outputs = merge_batch(outputs) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index fe228e227..6ba7dc629 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -17,28 +17,24 @@ class PipelineStageManager: Attributes: num_stages (int): Number of stages in the pipeline. stage (int): The current stage. - num_virtual_stages (int): Number of virtual stages in the pipeline. - virtual_stage (int): The current virtual stage. """ - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis - self.num_virtual_stages: Optional[int] = None - self.virtual_stage: Optional[int] = None self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} # init prev and next coord coord = self.pg_mesh.coordinate() - if self.stage > 0: - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) - if self.stage < self.num_stages - 1: - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + # the prev rank of rank0 is the last rank + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + # the next rank of the last rank is rank0 + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') # init p2p process groups stages = list(range(self.num_stages)) @@ -48,32 +44,28 @@ class PipelineStageManager: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self, virtual: bool = False) -> bool: - """Is the current stage the first stage. + if is_virtual: + # add the process group of the first rank and the last rank + # only used in interleaved pipeline for now + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) + if self.stage in [stages[0], stages[-1]]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + def is_first_stage(self) -> bool: + """Is the current stage the first stage. Returns: bool: Whether the current stage is the first stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == 0 return self.stage == 0 - def is_last_stage(self, virtual: bool = False) -> bool: + def is_last_stage(self) -> bool: """Is the current stage the last stage. - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. - Returns: bool: Whether the current stage is the last stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == self.num_virtual_stages - 1 return self.stage == self.num_stages - 1 @property @@ -108,7 +100,6 @@ class PipelineStageManager: Returns: int: Rank of the previous stage. """ - assert not self.is_first_stage(), "Cannot get previous rank in the first stage." return self.prev_rank def get_next_rank(self) -> int: @@ -117,39 +108,8 @@ class PipelineStageManager: Returns: int: Rank of the next stage. """ - assert not self.is_last_stage(), "Cannot get next rank in the last stage." return self.next_rank - def set_num_virtual_stages(self, num_virtual_stages: int) -> None: - """Set the number of virtual stages. - - Args: - num_virtual_stages (int): Number of virtual stages. - """ - self.num_virtual_stages = num_virtual_stages - - def set_virtual_stage(self, virtual_stage: int) -> None: - """Set the virtual stage. - - Args: - virtual_stage (int): Virtual stage. - """ - self.virtual_stage = virtual_stage - - @contextmanager - def switch_virtual_stage(self, virtual_stage: int) -> None: - """A context manager to switch virtual stage. - - Args: - virtual_stage (int): Target virtual stage. - """ - old_stage = self.virtual_stage - try: - self.set_virtual_stage(virtual_stage) - yield - finally: - self.set_virtual_stage(old_stage) - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py new file mode 100644 index 000000000..2ac31c8ca --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -0,0 +1,161 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 8) + self.linear3 = nn.Linear(8, 8) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + x = self.linear7(x) + x = self.linear8(x) + return x + + +def pp_linear_fwd(forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None): + + if stage_mgr.is_first_stage() and model_chunk_id == 0: + return {'input_obj': forward(data)} + elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + return forward(input_obj) + else: + return {'input_obj': forward(input_obj)} + + +@parameterize("num_micro_batches", [4, 8, 12]) +def examine_pp(num_micro_batches): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = num_micro_batches + BATCH_SIZE = num_micro_batches + NUM_CHUNKS = 2 + + # create model + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) + schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + + sharded_model = torch.nn.ModuleList() + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, + stage_mgr=stage_manager, + num_chunks=NUM_CHUNKS, + model_chunk_id=len(sharded_model)), sub_model._forward) + sharded_model.append(sub_model.cuda()) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if local_rank == 0: + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, + pp_optimizer, + iter(input_list), + criterion, + return_loss=True, + return_outputs=True) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret['loss']) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + else: + assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + else: + assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_pp() diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58..6e0cd1998 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -49,15 +49,6 @@ def check_stage_manager(): next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] assert stage_manager.get_next_rank() == next_rank - # check virtual stage - stage_manager.set_num_virtual_stages(PP_SIZE * 2) - assert stage_manager.num_virtual_stages == PP_SIZE * 2 - stage_manager.set_virtual_stage(stage_manager.stage * 2) - assert stage_manager.virtual_stage == stage_manager.stage * 2 - with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): - assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1 - assert stage_manager.virtual_stage == stage_manager.stage * 2 - # check p2p groups for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): if rank in [prev, cur]: From 7c8be770810835544e2652c6d053e77db83a0949 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:21:53 +0800 Subject: [PATCH 05/33] [shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460) * support gpt2 seq parallel with pp/dp/tp * fix a bug when waiting for stream done * delete unused gpt2_seq file --- .../booster/plugin/hybrid_parallel_plugin.py | 4 + colossalai/shardformer/layer/_operation.py | 2 + colossalai/shardformer/modeling/gpt2.py | 256 +++++++++++++++++- colossalai/shardformer/modeling/gpt2_seq.py | 222 --------------- colossalai/shardformer/policies/gpt2.py | 14 +- .../test_model/test_shard_gpt2.py | 10 +- 6 files changed, 268 insertions(+), 240 deletions(-) delete mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 00c714fe4..155f72dc6 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -235,6 +235,10 @@ class HybridParallelPlugin(PipelinePluginBase): assert dist.get_world_size() % ( tp_size * pp_size ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + + if enable_sequence_parallelism: + assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + # TODO(ver217): support zero assert zero_stage == 0, 'zero is not support yet' self.tp_size = tp_size diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 13e563123..fc13aca79 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -239,6 +239,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() torch.cuda.current_stream().wait_stream(calculate_stream) + gather_handle.wait() reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) with torch.cuda.stream(calculate_stream): @@ -249,6 +250,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): grad_weight = grad_output.t().matmul(input_parallel) torch.cuda.current_stream().wait_stream(calculate_stream) + reducescatter_handle.wait() return output, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 47835d5d5..722f0f523 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,6 +21,8 @@ from transformers.models.gpt2.modeling_gpt2 import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig class GPT2PipelineForwards: @@ -47,7 +49,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -159,6 +162,13 @@ class GPT2PipelineForwards: all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): @@ -212,6 +222,12 @@ class GPT2PipelineForwards: if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -257,7 +273,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -285,7 +302,8 @@ class GPT2PipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -335,7 +353,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - @@ -367,7 +386,8 @@ class GPT2PipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -421,7 +441,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -449,7 +470,8 @@ class GPT2PipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -508,7 +530,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -534,7 +557,8 @@ class GPT2PipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -578,7 +602,8 @@ class GPT2PipelineForwards: return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -613,7 +638,8 @@ class GPT2PipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -696,7 +722,6 @@ def get_gpt2_flash_attention_forward(): output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: _, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -753,3 +778,210 @@ def get_gpt2_flash_attention_forward(): return outputs return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py deleted file mode 100644 index a6da96e7b..000000000 --- a/colossalai/shardformer/modeling/gpt2_seq.py +++ /dev/null @@ -1,222 +0,0 @@ -# this code is modified from transformers.models.gpt2.modeling_gpt2 -# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 - -from typing import Optional, Tuple, Union - -import torch -import torch.distributed as dist -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.utils import logging - -from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.shard import ShardConfig - -logger = logging.get_logger(__name__) - - -# TODO: put all contents in `gpt2.py` and make it compatible with pipeline -def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 276d95660..d34c0ae9f 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,8 +6,7 @@ from torch import Tensor, nn import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward -from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -50,8 +49,6 @@ class GPT2Policy(Policy): target_module=col_nn.DropoutForParallelInput, ), ]) - if self.shard_config.enable_sequence_parallelism: - policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -126,6 +123,7 @@ class GPT2Policy(Policy): }) if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) @@ -169,7 +167,13 @@ class GPT2Policy(Policy): layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 97295f72f..0e29f1dd9 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -105,10 +105,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, - 'enable_all_optimization': False, + 'enable_all_optimization': True, 'use_lazy_init': True, 'enable_sequence_parallelism': True, 'precision': 'fp32', From 0ecd71e041e517808097f09eafc97811f93d4235 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 18 Aug 2023 15:34:18 +0800 Subject: [PATCH 06/33] [shardformer] bloom support sequence parallel (#4465) [shardformer] bloom support sequence parallel --- colossalai/shardformer/modeling/bloom.py | 184 ++++++++++++++++++- colossalai/shardformer/policies/bloom.py | 24 ++- colossalai/shardformer/shard/shard_config.py | 1 + 3 files changed, 201 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 12276635e..66f24dc60 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -23,6 +23,10 @@ from transformers.models.bloom.modeling_bloom import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -111,6 +115,7 @@ class BloomPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: @@ -205,6 +210,13 @@ class BloomPipelineForwards: past_key_values_length=past_key_values_length, ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx): @@ -248,6 +260,12 @@ class BloomPipelineForwards: all_self_attentions = all_self_attentions + \ (outputs[2 if use_cache else 1],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -287,6 +305,7 @@ class BloomPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -327,7 +346,8 @@ class BloomPipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -380,6 +400,7 @@ class BloomPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -424,6 +445,7 @@ class BloomPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -503,6 +525,7 @@ class BloomPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -547,6 +570,7 @@ class BloomPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -597,6 +621,7 @@ class BloomPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -632,6 +657,7 @@ class BloomPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -700,8 +726,7 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False): fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + batch_size, tgt_len, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -896,3 +921,156 @@ def get_jit_fused_bloom_gelu_forward(): return self.bloom_gelu_forward(x, bias) return forward + + +def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): + + from transformers import BloomModel + + def forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index b35764db3..2727272d0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -12,6 +12,7 @@ from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, get_bloom_flash_attention_forward, + get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_mlp_forward, @@ -43,6 +44,7 @@ class BloomPolicy(Policy): policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -53,11 +55,11 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, @@ -65,11 +67,11 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), ]) policy[BloomModel] = ModulePolicyDescription( @@ -116,6 +118,12 @@ class BloomPolicy(Policy): policy=policy, target_key=BloomBlock) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BloomModel) + if self.shard_config.enable_flash_attention: policy[BloomAttention] = ModulePolicyDescription(method_replacement={ 'forward': get_bloom_flash_attention_forward(), @@ -154,7 +162,13 @@ class BloomPolicy(Policy): layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a36e878c6..900f8475c 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -58,3 +58,4 @@ class ShardConfig: self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True + self.enable_sequence_parallelism = True From a27e0bb494c1260678df0587419913340fda0c1d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 18 Aug 2023 18:04:55 +0800 Subject: [PATCH 07/33] [shardformer] bert support sequence parallel. (#4455) * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel [shardformer] bert support sequence parallel * [shardformer] bert support sequence parallel --- colossalai/shardformer/layer/_operation.py | 6 +- colossalai/shardformer/modeling/bert.py | 246 ++++++++++++++++++--- colossalai/shardformer/policies/bert.py | 24 +- 3 files changed, 234 insertions(+), 42 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index fc13aca79..f1f48273c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -154,7 +154,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -217,9 +217,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): # do all gather in default stream input_ = input_.contiguous() world_size = dist.get_world_size(process_group) - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) # calculate gradient in calculate_stream @@ -469,9 +467,7 @@ def _gather(input_, dim=-1, process_group=None): # all gather input_ = input_.contiguous() - rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=process_group) # concat diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 5bd1c531c..d88661953 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward class BertPipelineForwards: @@ -56,6 +58,7 @@ class BertPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. r""" @@ -177,6 +180,14 @@ class BertPipelineForwards: start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -223,11 +234,17 @@ class BertPipelineForwards: all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + sequence_output = hidden_states if hidden_states is not None else None if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -268,6 +285,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) @@ -294,6 +312,7 @@ class BertPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -350,6 +369,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -404,7 +424,8 @@ class BertPipelineForwards: return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -457,6 +478,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -491,6 +513,7 @@ class BertPipelineForwards: hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -532,6 +555,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **kwargs, ): # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -594,7 +618,8 @@ class BertPipelineForwards: return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -636,6 +661,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -666,7 +692,8 @@ class BertPipelineForwards: return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -726,6 +753,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -742,21 +770,20 @@ class BertPipelineForwards: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -799,6 +826,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -843,6 +871,7 @@ class BertPipelineForwards: hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -886,6 +915,7 @@ class BertPipelineForwards: hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # NOTE: the arg start_position and end_position are used only for the last stage r""" @@ -909,21 +939,20 @@ class BertPipelineForwards: logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -1101,3 +1130,150 @@ def get_jit_fused_bert_output_forward(): return hidden_states return forward + + +def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + embedding_output = split_forward_gather_backward(embedding_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + sequence_output = gather_forward_split_backward(sequence_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ace9ada39..fe091c658 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, + bert_sequence_parallel_forward_fn, get_bert_flash_attention_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -47,13 +48,14 @@ class BertPolicy(Policy): from transformers.models.bert.modeling_bert import ( BertEmbeddings, BertLayer, + BertModel, BertOutput, BertSelfAttention, BertSelfOutput, ) policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -69,14 +71,17 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -85,6 +90,7 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -93,10 +99,12 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -115,6 +123,12 @@ class BertPolicy(Policy): ) ]) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BertModel) + # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer @@ -205,7 +219,13 @@ class BertPolicy(Policy): layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) From 8739aa7fa01a9c04743dea813e2cc210e30dd77f Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 18 Aug 2023 21:29:25 +0800 Subject: [PATCH 08/33] [shardformer] Pipeline/whisper (#4456) * add some base tests and policies * finish whisper base model * add conditional generation * finish basic tests * whisper * finish whisper * finish whisper * del useless whisper test * fix * add argmin to replace * finish revision --- colossalai/shardformer/modeling/whisper.py | 715 +++++++++++++++++- colossalai/shardformer/policies/blip2.py | 9 - colossalai/shardformer/policies/t5.py | 9 +- colossalai/shardformer/policies/whisper.py | 243 +++++- .../test_t5_pipeline_utils.py | 39 + .../test_whisper_pipeline_utils.py | 44 ++ .../test_model/test_shard_llama.py | 2 + .../test_model/test_shard_whisper.py | 148 +++- 8 files changed, 1156 insertions(+), 53 deletions(-) create mode 100644 tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py create mode 100644 tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 0a16c6f78..62f8f7b47 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,7 +1,26 @@ -from typing import Optional, Tuple +import logging +import random +from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperForAudioClassification, + WhisperForConditionalGeneration, + WhisperModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): @@ -247,3 +266,697 @@ def get_jit_fused_whisper_decoder_layer_forward(): return outputs return forward + + +class WhisperPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + @staticmethod + def whisper_encoder_forward( + self: WhisperEncoder, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + + stage = stage_manager.stage + at_first_stage = (stage == 0) + at_last_stage = (stage == decoder_starting_stage - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process inputs if at the first stage of encoder. + if at_first_stage: + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + encoder_layer = self.layers[idx] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + else: + return {'hidden_states': hidden_states, 'head_mask': head_mask} + + @staticmethod + def whisper_decoder_forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + stage = stage_manager.stage + at_first_stage = (stage == decoder_starting_stage) + at_last_stage = (stage == stage_manager.num_stages - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if at_first_stage: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + else: + + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + input_shape = hidden_states.size()[:-1] + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, + past_key_values_length) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + else: + return { + 'head_mask': head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'hidden_states': hidden_states, + } + + @staticmethod + def whisper_model_forward( + self: WhisperModel, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + in_decoder = stage_manager.stage >= decoder_starting_stage + if not in_decoder: + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_hidden_states': encoder_outputs[0]} + else: + return encoder_outputs + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded Whisper forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def whisper_for_conditional_generation_forward( + self: WhisperForConditionalGeneration, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + in_decoder = stage_manager.stage >= decoder_starting_stage + at_last_decoder_stage = stage_manager.is_last_stage() + outputs = WhisperPipelineForwards.whisper_model_forward(self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if not in_decoder: + return outputs + + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + outputs['encoder_hidden_states'] = encoder_hidden_states + return outputs + + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def whisper_for_audio_classification_forward( + self: WhisperForAudioClassification, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. + Please refer to original code of transformers for more details. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # audio_classification only holds encoder + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if not stage_manager.is_last_stage(): + return encoder_outputs + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 50356302e..3610e2c41 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -304,15 +304,6 @@ class BlipPolicy(Policy): return policy def postprocess(self): - binding_map = { - 'language_model.model.decoder.embed_tokens': 'language_model.lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - return self.model diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 2ef52c214..651883d35 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,7 @@ from functools import partial from typing import Callable, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor, nn from colossalai.shardformer.layer import ( @@ -228,13 +229,7 @@ class T5BasePolicy(Policy): def objective(num_encoder_stages): return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) - num_encoder_stages = 0 - optimal_diff = 2**31 - 1 - for i in range(1, num_stages): - attempt = objective(i) - if attempt < optimal_diff: - num_encoder_stages = i - optimal_diff = attempt + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 2ac7a49fd..a33f929f1 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,10 +1,16 @@ +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np import torch.nn as nn +from torch import Tensor import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( + WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, get_whisper_flash_attention_forward, @@ -12,7 +18,8 @@ from ..modeling.whisper import ( from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', + 'WhisperForAudioClassificationPolicy' ] @@ -223,6 +230,146 @@ class WhisperPolicy(Policy): def postprocess(self): return self.model + @staticmethod + def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, + num_stages: int) -> Tuple[List[int], int]: + """ + Distribute whisper layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for whisper must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_whisper_stage_index(layers_per_stage: List[int], stage: int, + decoder_starting_stage: int) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + # whisper for audio classification holds encoder only + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + # WhisperModel class WhisperModelPolicy(WhisperPolicy): @@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers import WhisperModel + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in whisper model" + return [] + # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): @@ -238,20 +403,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + from transformers import WhisperForConditionalGeneration + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy) + return policy def postprocess(self): - binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) return self.model + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + model = module.model + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, + stage_manager.num_stages) + shared_params = [] + shared_embedding = {} + if id(module.proj_out) == id(model.decoder.embed_tokens): + shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens + shared_embedding[stage_manager.num_stages - 1] = module.proj_out + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + return shared_params + return [] + # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers import WhisperForAudioClassification + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py new file mode 100644 index 000000000..0cbb852b9 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -0,0 +1,39 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py new file mode 100644 index 000000000..395519e97 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -0,0 +1,44 @@ +from colossalai.shardformer.policies.whisper import WhisperPolicy + + +def test_whisper_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_whisper_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end + + +if __name__ == '__main__': + test_whisper_pipeline_distribution() + test_whisper_pipeline_layers() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a433567b3..ec5578a76 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -6,6 +6,7 @@ from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -143,6 +144,7 @@ def run_llama_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 9b38ae07b..90e007e34 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -3,6 +3,8 @@ import torch import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -11,55 +13,145 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - # do backward - org_loss.backward() - shard_loss.backward() + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model - sharded_whisper = sharded_model.model + sharded_whisper = sharded_model.unwrap().model else: whisper = org_model - sharded_whisper = sharded_model + sharded_whisper = sharded_model.unwrap() # check grad if org_model.__class__.__name__ == 'WhisperForAudioClassification': col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] else: - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj'] - check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + col_layer_for_check = [ + 'encoder.layers[0].self_attn.q_proj', + # 'decoder.layers[0].self_attn.q_proj' + ] + row_layer_for_check = [ + 'encoder.layers[0].self_attn.out_proj', + #'decoder.layers[0].self_attn.out_proj' + ] + + # check weights and gradients + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + check_weight(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): +# TODO(jianghai) fix fp16 +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', +}]) +def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, - enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_whisper(): - spawn(check_whisper, 2) + spawn(check_whisper, 4) if __name__ == "__main__": From 1c7df566e23d5b94512f0777e0475df0a0ae1072 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 21 Aug 2023 12:04:52 +0800 Subject: [PATCH 09/33] [shardformer] support tp+zero for shardformer (#4472) * support tp+zero/input type cast for hybridplugin * add tp+zero tests * fix bucket arguments --- .../booster/plugin/hybrid_parallel_plugin.py | 89 +++++++++++++------ .../test_model/test_shard_bert.py | 12 ++- .../test_model/test_shard_bloom.py | 10 ++- .../test_model/test_shard_chatglm.py | 10 ++- .../test_model/test_shard_gpt2.py | 10 ++- .../test_model/test_shard_llama.py | 10 ++- .../test_model/test_shard_opt.py | 10 ++- .../test_model/test_shard_t5.py | 12 ++- .../test_model/test_shard_vit.py | 10 ++- 9 files changed, 136 insertions(+), 37 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 155f72dc6..016323ae7 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,5 +1,6 @@ import random from contextlib import nullcontext +from functools import partial from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np @@ -10,6 +11,7 @@ from torch.nn import Module, SyncBatchNorm from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -27,32 +29,49 @@ from .pp_plugin_base import PipelinePluginBase DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, ddp_config: dict) -> None: + self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group + shardformer = ShardFormer(shard_config) module, self.shared_params = shardformer.optimize(module) - # TODO(ver217): add input type cast + + # setting process groups for shared parameters self.shared_param_process_groups = [] for shared_param in self.shared_params: if len(shared_param) > 0: self.shared_param_process_groups.append( self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + + # setting mixed_precision + self.mixed_precision = None if precision == 'fp16': - module = module.half().cuda() + self.mixed_precision = torch.float16 elif precision == 'bf16': - module = module.to(dtype=torch.bfloat16).cuda() - else: - module = module.cuda() # train without AMP + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.cuda() + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs if use_ddp: - # convert model to sync bn module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) - # wrap the model with PyTorch DDP module = DDP(module, process_group=dp_group, **ddp_config) @@ -78,6 +97,12 @@ class HybridParallelModule(ModelWrapper): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + def unwrap(self): module = super().unwrap() if isinstance(module, DDP): @@ -180,7 +205,6 @@ class HybridParallelPlugin(PipelinePluginBase): Defaults to 'fp16'. zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. When set to 0, ZeRO will not be used. Defaults to 0. - cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. Currently all the optimization methods include fused normalization, flash attention and JIT. Defaults to False. @@ -196,12 +220,16 @@ class HybridParallelPlugin(PipelinePluginBase): hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. - broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True. - bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25. - find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False. - check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False. - gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False. - static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. """ def __init__(self, @@ -209,7 +237,6 @@ class HybridParallelPlugin(PipelinePluginBase): pp_size: int, precision: str = 'fp16', zero_stage: int = 0, - cpu_offload: bool = False, enable_all_optimization: bool = False, enable_fused_normalization: bool = False, enable_flash_attention: bool = False, @@ -224,12 +251,16 @@ class HybridParallelPlugin(PipelinePluginBase): hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0, - broadcast_buffers=True, - bucket_cap_mb=25, - find_unused_parameters=False, - check_reduction=False, - gradient_as_bucket_view=False, - static_graph=False) -> None: + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True) -> None: super().__init__() assert dist.get_world_size() % ( @@ -239,8 +270,6 @@ class HybridParallelPlugin(PipelinePluginBase): if enable_sequence_parallelism: assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' - # TODO(ver217): support zero - assert zero_stage == 0, 'zero is not support yet' self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -282,11 +311,18 @@ class HybridParallelPlugin(PipelinePluginBase): ) self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, + bucket_cap_mb=ddp_bucket_cap_mb, find_unused_parameters=find_unused_parameters, check_reduction=check_reduction, gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + self.max_norm = max_norm @property @@ -337,15 +373,16 @@ class HybridParallelPlugin(PipelinePluginBase): model, use_pipeline=self.enable_pipeline_parallelism) else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, - partition_grad=(self.zero_stage == 2), - cpu_offload=self.cpu_offload, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, + **self.zero_config, **self.amp_config) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 49de9cc03..c96701704 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -56,9 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3) - #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3) + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) @@ -101,6 +99,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index af014a858..bd87be8b7 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] col_layer_for_check = ['h[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-5 else: @@ -101,6 +101,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bloom_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py index 210f775b5..64732e06b 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm.py @@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] col_layer_for_check = ['encoder.layers[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: @@ -125,6 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_chatglm_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 0e29f1dd9..c776a80d8 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -56,7 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] # check grad - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 else: @@ -120,6 +120,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'use_lazy_init': True, 'enable_sequence_parallelism': True, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_gpt2_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ec5578a76..7140c4666 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-4 else: @@ -135,6 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_llama_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 2fb14903b..e6faafdae 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -58,7 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-6, 1e-3 else: @@ -127,6 +127,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_opt_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 234ce812a..599f5a80d 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -55,12 +55,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] - # check weights and gradients + # check grad if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() @@ -110,6 +110,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_t5_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index b9d303841..b27add24c 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] col_layer_for_check = ['encoder.layer[0].attention.output.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: @@ -124,6 +124,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_vit_test(test_config): From 5545114fd84c8aa39b18aa0ad8816ddbc6dab360 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:13:31 +0800 Subject: [PATCH 10/33] rename chatglm to chatglm2 (#4484) --- colossalai/shardformer/modeling/{chatglm.py => chatglm2.py} | 0 colossalai/shardformer/policies/auto_policy.py | 4 ++-- colossalai/shardformer/policies/{chatglm.py => chatglm2.py} | 4 ++-- tests/kit/model_zoo/transformers/__init__.py | 2 +- tests/kit/model_zoo/transformers/{chatglm.py => chatglm2.py} | 0 .../{test_shard_chatglm.py => test_shard_chatglm2.py} | 0 6 files changed, 5 insertions(+), 5 deletions(-) rename colossalai/shardformer/modeling/{chatglm.py => chatglm2.py} (100%) rename colossalai/shardformer/policies/{chatglm.py => chatglm2.py} (98%) rename tests/kit/model_zoo/transformers/{chatglm.py => chatglm2.py} (100%) rename tests/test_shardformer/test_model/{test_shard_chatglm.py => test_shard_chatglm2.py} (100%) diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm2.py similarity index 100% rename from colossalai/shardformer/modeling/chatglm.py rename to colossalai/shardformer/modeling/chatglm2.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index eec339c02..2fe49f0d5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -125,9 +125,9 @@ _POLICY_LIST = { # ChatGLM "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm2.py similarity index 98% rename from colossalai/shardformer/policies/chatglm.py rename to colossalai/shardformer/policies/chatglm2.py index e6b458936..a15aa856d 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -15,7 +15,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( GLMBlock, ) -from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 823ca032f..2a492361b 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,7 @@ from .albert import * from .bert import * from .blip2 import * from .bloom import * -from .chatglm import * +from .chatglm2 import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm2.py similarity index 100% rename from tests/kit/model_zoo/transformers/chatglm.py rename to tests/kit/model_zoo/transformers/chatglm2.py diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py similarity index 100% rename from tests/test_shardformer/test_model/test_shard_chatglm.py rename to tests/test_shardformer/test_model/test_shard_chatglm2.py From 351351a36eb9d11e5bdb3610b0d3705055d90e7d Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 22 Aug 2023 17:35:35 +0800 Subject: [PATCH 11/33] [shardformer/sequence parallel] not support opt of seq-parallel, add warning and fix a bug in gpt2 pp (#4488) --- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/policies/opt.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 722f0f523..8ed367b25 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -148,7 +148,7 @@ class GPT2PipelineForwards: if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) + hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ba6036bd0..58663553b 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -39,6 +40,9 @@ class OPTPolicy(Policy): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ From 59e252ecdbab0fe56fd3bacc9833188fe5285d02 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 22 Aug 2023 23:59:31 +0800 Subject: [PATCH 12/33] [shardformer] chatglm support sequence parallel (#4482) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix --- colossalai/shardformer/layer/linear.py | 10 +- colossalai/shardformer/modeling/chatglm2.py | 135 ++++++++++++++++--- colossalai/shardformer/policies/bert.py | 18 ++- colossalai/shardformer/policies/blip2.py | 23 ++-- colossalai/shardformer/policies/bloom.py | 26 ++-- colossalai/shardformer/policies/chatglm2.py | 101 +++++++++----- colossalai/shardformer/policies/gpt2.py | 6 +- colossalai/shardformer/policies/llama.py | 6 +- colossalai/shardformer/policies/sam.py | 12 +- colossalai/shardformer/policies/vit.py | 12 +- tests/kit/model_zoo/transformers/chatglm2.py | 4 +- 11 files changed, 259 insertions(+), 94 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 69ac3ad25..81c3f973f 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -74,6 +74,7 @@ class Linear1D_Col(ParallelModule): process_group: ProcessGroup = None, gather_output: bool = False, seq_parallel: bool = False, + seq_parallel_dim: int = 1, overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -87,6 +88,7 @@ class Linear1D_Col(ParallelModule): self.out_features = out_features self.gather_output = gather_output self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -190,7 +192,8 @@ class Linear1D_Col(ParallelModule): bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1, self.overlap) + self.process_group, True, + self.seq_parallel_dim, self.overlap) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -236,6 +239,7 @@ class Linear1D_Row(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, seq_parallel: bool = False, + seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -254,6 +258,7 @@ class Linear1D_Row(ParallelModule): self.skip_bias_add = skip_bias_add self.process_group = process_group self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -390,7 +395,8 @@ class Linear1D_Row(ParallelModule): else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, + self.seq_parallel_dim) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 409e2e1f5..16dcf87c8 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -146,6 +148,7 @@ class ChatGLMPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) output_hidden_states = (output_hidden_states @@ -198,6 +201,11 @@ class ChatGLMPipelineForwards: all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] + + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -214,6 +222,11 @@ class ChatGLMPipelineForwards: hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) + + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -233,23 +246,22 @@ class ChatGLMPipelineForwards: return {'hidden_states': hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): + def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None): logger = logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) @@ -266,6 +278,7 @@ class ChatGLMPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -296,3 +309,91 @@ class ChatGLMPipelineForwards: ) else: return transformer_outputs + + +def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] + inputs_embeds = split_forward_gather_backward(inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fe091c658..19dd95fd6 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -155,20 +155,26 @@ class BertPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bert_flash_attention_forward(), - }) + }, + policy=policy, + target_key=BertSelfAttention) # use jit operator if self.shard_config.enable_jit_fused: - policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_self_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BertOutput] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BertSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=BertOutput) return policy diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 3610e2c41..2e5388ab0 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -285,21 +285,26 @@ class BlipPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_blip2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=Blip2Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[Blip2QFormerSelfOutput] = ModulePolicyDescription( - method_replacement={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_blip2_QFormer_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=Blip2QFormerOutput) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2727272d0..21db13f6e 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -125,25 +125,33 @@ class BloomPolicy(Policy): target_key=BloomModel) if self.shard_config.enable_flash_attention: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func() - }) + 'dropout_add': get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention) # enable jit fused operator if self.shard_config.enable_jit_fused: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomAttention) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomMLP) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_gelu_forward(), 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }) + }, + policy=policy, + target_key=BloomGelu) return policy diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index a15aa856d..b0d684a67 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -15,7 +15,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( GLMBlock, ) -from ..modeling.chatglm2 import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_forward_fn, + get_flash_core_attention_forward, + get_jit_fused_glm_block_forward, +) from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -45,8 +49,8 @@ class ChatGLMPolicy(Policy): policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ SubModuleReplacementDescription( @@ -55,36 +59,42 @@ class ChatGLMPolicy(Policy): ) ]) - policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.core_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription(suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: @@ -124,16 +134,27 @@ class ChatGLMPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_flash_core_attention_forward(), - }) + }, + policy=policy, + target_key=CoreAttention) + + # use sequence parallel + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=ChatGLMModel) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_glm_block_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=GLMBlock) return policy @@ -178,7 +199,13 @@ class ChatGLMPolicy(Policy): layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d34c0ae9f..acae26309 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -118,9 +118,11 @@ class GPT2Policy(Policy): target_key=GPT2Block) if self.shard_config.enable_flash_attention: - policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_gpt2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=GPT2Attention) if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5ee95f3be..ccf776407 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -105,9 +105,11 @@ class LlamaPolicy(Policy): target_key=LlamaModel) if self.shard_config.enable_flash_attention: - policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_llama_flash_attention_forward(), - }) + }, + policy=policy, + target_key=LlamaAttention) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index b1eba0432..9753d5a73 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -199,12 +199,16 @@ class SamPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[SamAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_sam_flash_attention_forward(), - }) - policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=SamAttention) + self.append_or_create_method_replacement(description={ 'forward': get_sam_vision_flash_attention_forward(), - }) + }, + policy=policy, + target_key=SamVisionAttention) return policy diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 617720ee7..757bab95f 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -90,16 +90,20 @@ class ViTPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_vit_flash_self_attention_forward(), - }) + }, + policy=policy, + target_key=ViTSelfAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[ViTOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_vit_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=ViTOutput) return policy def new_model_class(self): diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index c6473ee2a..d543df00b 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -12,8 +12,8 @@ from ..registry import ModelAttribute, model_zoo def data_gen(): - input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) From e04436a82aa847db166cb181053c290c8a150496 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 23 Aug 2023 15:05:24 +0800 Subject: [PATCH 13/33] [shardformer] tests for 3d parallel (#4493) --- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_bert.py | 36 ++++++++++++++++ .../test_model/test_shard_bloom.py | 38 +++++++++++++++++ .../test_model/test_shard_chatglm2.py | 35 ++++++++++++++++ .../test_model/test_shard_gpt2.py | 36 ++++++++++++++++ .../test_model/test_shard_llama.py | 37 ++++++++++++++++- .../test_model/test_shard_opt.py | 35 ++++++++++++++++ .../test_model/test_shard_t5.py | 35 ++++++++++++++++ .../test_model/test_shard_vit.py | 35 ++++++++++++++++ .../test_model/test_shard_whisper.py | 41 +++++++++++++++++-- 10 files changed, 324 insertions(+), 5 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 789b3b24e..811471bec 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -245,7 +245,6 @@ def check_grad(org_model: Module, org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index c96701704..76f8c0541 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -120,12 +120,40 @@ def run_bert_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_bert_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_bert_test() +def check_bert_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -133,5 +161,13 @@ def test_bert(): spawn(check_bert, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert_3d(): + spawn(check_bert_3d, 8) + + if __name__ == "__main__": test_bert() + test_bert_3d() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index bd87be8b7..0e236fd47 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,6 +3,7 @@ import torch import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -118,6 +119,29 @@ def run_bloom_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_bloom_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -127,6 +151,12 @@ def check_bloom(rank, world_size, port): run_bloom_test() +def check_bloom_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -134,5 +164,13 @@ def test_bloom(): spawn(check_bloom, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_3d(): + spawn(check_bloom_3d, 8) + + if __name__ == "__main__": test_bloom() + test_bloom_3d() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 64732e06b..a8957d8d3 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -145,12 +145,39 @@ def run_chatglm_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_chatglm_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_chatglm(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_chatglm_test() +def check_chatglm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -158,5 +185,13 @@ def test_chatglm(): spawn(check_chatglm, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm_3d(): + spawn(check_chatglm_3d, 8) + + if __name__ == "__main__": test_chatglm() + test_chatglm_3d() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c776a80d8..85d66e493 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -141,12 +141,40 @@ def run_gpt2_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +@clear_cache_before_run() +def run_gpt2_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_gpt2(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_gpt2_test() +def check_gpt2_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -154,5 +182,13 @@ def test_gpt2(): spawn(check_gpt2, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2_3d(): + spawn(check_gpt2_3d, 8) + + if __name__ == "__main__": test_gpt2() + test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 7140c4666..485d2685e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -56,7 +56,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # unwrap model llama_model = unwrap_model(org_model, 'LlamaModel', 'model') shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') - # check grad row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] col_layer_for_check = ['layers[0].self_attn.o_proj'] @@ -156,12 +155,40 @@ def run_llama_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_llama_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_llama_test() +def check_llama_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -169,5 +196,13 @@ def test_llama(): spawn(check_llama, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) + + if __name__ == "__main__": test_llama() + test_llama_3d() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index e6faafdae..ad344585e 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -146,12 +146,39 @@ def run_opt_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_opt_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_OPTModel(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_opt_test() +def check_opt_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_opt_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -159,5 +186,13 @@ def test_OPTModel(): spawn(check_OPTModel, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_opt_3d(): + spawn(check_opt_3d, 8) + + if __name__ == '__main__': test_OPTModel() + test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 599f5a80d..a853f024d 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -137,12 +137,39 @@ def run_t5_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_t5_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_t5_test() +def check_t5_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -150,5 +177,13 @@ def test_t5(): spawn(check_t5, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5_3d(): + spawn(check_t5_3d, 8) + + if __name__ == "__main__": test_t5() + test_t5_3d() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index b27add24c..0b092966c 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -146,12 +146,39 @@ def run_vit_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_vit_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_vit(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_vit_test() +def check_vit_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -159,5 +186,13 @@ def test_vit(): spawn(check_vit, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit_3d(): + spawn(check_vit_3d, 8) + + if __name__ == "__main__": test_vit() + test_vit_3d() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 90e007e34..6445b314d 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -82,8 +82,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) # check weights after optimizer.step() org_optimizer.step() @@ -99,7 +99,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, tp_group, atol=atol, rtol=rtol, - dim=0, + dim=1, verbose=False) check_weight(whisper, sharded_whisper, @@ -155,12 +155,39 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_whisper_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_whisper(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_whisper_test() +def check_whisper_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_whisper_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -168,5 +195,13 @@ def test_whisper(): spawn(check_whisper, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) + + if __name__ == "__main__": test_whisper() + test_whisper_3d() From 3353e55c80d22c765314ca4f4886d61f0a58cdd7 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 24 Aug 2023 15:50:02 +0800 Subject: [PATCH 14/33] [shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fix. (#4498) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * activate checks --- colossalai/shardformer/modeling/bert.py | 6 +++++ colossalai/shardformer/policies/llama.py | 5 ++++ colossalai/shardformer/policies/opt.py | 12 ++++++--- colossalai/shardformer/policies/t5.py | 5 ++++ colossalai/shardformer/policies/vit.py | 5 ++++ colossalai/shardformer/policies/whisper.py | 27 +++++++++---------- .../test_model/test_shard_whisper.py | 7 ++--- 7 files changed, 46 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index d88661953..30855a622 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -187,6 +187,9 @@ class BertPipelineForwards: hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -1241,6 +1244,9 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): embedding_output = split_forward_gather_backward(embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) encoder_outputs = self.encoder( embedding_output, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ccf776407..c417e5d01 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -35,6 +36,10 @@ class LlamaPolicy(Policy): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement={ diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 58663553b..abe491bfa 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -104,16 +104,20 @@ class OPTPolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_opt_flash_attention_forward(), - }) + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_opt_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 651883d35..192a1b847 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Optional, Tuple @@ -59,6 +60,10 @@ class T5BasePolicy(Policy): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 757bab95f..b4fb8692e 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Dict, List, Union import torch.nn as nn @@ -32,6 +33,10 @@ class ViTPolicy(Policy): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index a33f929f1..bffb624d0 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Tuple @@ -33,7 +34,6 @@ class WhisperPolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - # TODO: vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: @@ -52,6 +52,14 @@ class WhisperPolicy(Policy): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_jit_fused: + self.shard_config.enable_jit_fused = False + warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim": @@ -198,20 +206,11 @@ class WhisperPolicy(Policy): # enable flash attention if self.shard_config.enable_flash_attention: - policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_whisper_flash_attention_forward(), - }) - - # use jit fused operator - if self.shard_config.enable_jit_fused: - policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_whisper_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=WhisperAttention) return policy diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6445b314d..011fb8d23 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights and gradients if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # TODO(jianghai) fix fp16 +#TODO fix WhisperForConditionalGeneration enable jit fused operator @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From de8a65babcf3bdf50fd1a60ff0baabe3e4f7803e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 25 Aug 2023 19:41:24 +0800 Subject: [PATCH 15/33] [shardformer] opt fix. (#4514) * [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel [shardformer] chatglm support sequence parallel * fix fix fix fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * [shardformer] jit fused fix * activate checks * [Test] test ci * test ci * test ci * test ci * test ci * test ci * test ci * fix --- colossalai/shardformer/policies/opt.py | 26 +++++++++---------- .../test_model/test_shard_opt.py | 1 - .../test_model/test_shard_whisper.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index abe491bfa..be9d1c58b 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -103,21 +103,21 @@ class OPTPolicy(Policy): target_key=OPTDecoderLayer) # use flash attention - if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_opt_flash_attention_forward(), - }, - policy=policy, - target_key=OPTAttention) + # if self.shard_config.enable_flash_attention: + # self.append_or_create_method_replacement(description={ + # 'forward': get_opt_flash_attention_forward(), + # }, + # policy=policy, + # target_key=OPTAttention) # use jit fused operator - if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_opt_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=OPTDecoderLayer) + # if self.shard_config.enable_jit_fused: + # self.append_or_create_method_replacement(description={ + # 'forward': get_jit_fused_opt_decoder_layer_forward(), + # 'dropout_add': get_jit_fused_dropout_add_func(), + # }, + # policy=policy, + # target_key=OPTDecoderLayer) return policy diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index ad344585e..71483b752 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -137,7 +137,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'initial_scale': 1 }]) def run_opt_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 011fb8d23..6eaed7d37 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_optimizer.step() sharded_optimizer.step() if test_config['precision'] == 'fp32': - atol, rtol = 2e-4, 2e-4 + atol, rtol = 5e-4, 5e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): From 44eab2b27f8f854d5fb050a3b5aa83e79effd0b6 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 22:04:57 +0800 Subject: [PATCH 16/33] [shardformer] support sharded checkpoint IO for models of HybridParallelPlugin (#4506) * add APIs * implement save_sharded_model * add test for hybrid checkpointio * implement naive loading for sharded model * implement efficient sharded model loading * open a new file for hybrid checkpoint_io * small fix * fix circular importing * fix docstring * arrange arguments and apis * small fix --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/checkpoint_io/__init__.py | 3 +- .../hybrid_parallel_checkpoint_io.py | 316 ++++++++++++++++++ colossalai/checkpoint_io/utils.py | 60 +++- .../shardformer/layer/parallel_module.py | 9 +- colossalai/zero/gemini/gemini_ddp.py | 28 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 116 +++++++ 7 files changed, 497 insertions(+), 40 deletions(-) create mode 100644 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py create mode 100644 tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 016323ae7..c49b3e182 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO +from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule @@ -292,6 +292,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, @@ -460,7 +461,7 @@ class HybridParallelPlugin(PipelinePluginBase): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return None + return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index c25048e25..07b1f81da 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,5 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py new file mode 100644 index 000000000..56a89bff7 --- /dev/null +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -0,0 +1,316 @@ +import copy +import gc +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.cluster import ProcessGroupMesh +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) + +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + calculate_tensor_size, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + get_shard_filename, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + save_param_groups, + save_state_dict, + save_state_dict_shards, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class HypridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + dp_group (ProcessGroup): Process group along data parallel dimension. + pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. + """ + + def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + super().__init__() + self.dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) + + @staticmethod + def _model_sharder(model: nn.Module, + prefix: str = '', + keep_vars: bool = False, + size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. + # TODO (Baizhou): Implement sharding feature of optimizer. + pass + + def save_sharded_model(self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_size == 0 are responsible for model saving. + state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + index_file_path (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model(model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True) + del state_dict + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + _load(extra_state_key) + + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + pass + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + pass + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 8837776ae..d04159c54 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -13,7 +13,12 @@ from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): + """ + Gather the complete parameter for saving if passed in param is distributed. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + # ====================================== -# Helper functions for saving shard file +# Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): ''' @@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): return unwrapped_optim +class StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size + + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -126,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] total_size = 0 for idx, shard_pair in enumerate(sharded_state_dict): - if not is_master: - continue shard, current_size = shard_pair + if not is_master: + del shard + continue shard_file = get_shard_filename(base_filename, idx) total_size = total_size + current_size for key in shard.keys(): @@ -137,6 +188,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + del shard return total_size diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index bda147b12..4f391920e 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, @@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC): """ for name, param in self._parameters.items(): if param is not None: - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - destination[prefix + name] = to_global(param_) - elif is_customized_distributed_tensor(param_): - destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) - else: - destination[prefix + name] = param_ + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 08384ee82..5aff91f03 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage @@ -657,7 +657,7 @@ class ZeroDDP(ColoDDP): Yields: Iterator[OrderedDict]: A generator of state dict shard """ - sharder = _StateDictSharder(max_shard_size) + sharder = StateDictSharder(max_shard_size) # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() @@ -705,30 +705,6 @@ class ZeroDDP(ColoDDP): yield sharder.current_block, sharder.current_block_size -class _StateDictSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size - - class GeminiDDP(ZeroDDP): def __init__(self, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py new file mode 100644 index 000000000..ea0922ef5 --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -0,0 +1,116 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('shard', [True]) +@parameterize('model_name', ['transformers_gpt']) +@parameterize('size_per_shard', [32]) +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}]) +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): + + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + data_iter = iter([data]) + output = booster.execute_pipeline(data_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + data = {k: v.cuda() for k, v in data.items()} + output = model(**data) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + # optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + + clear_layout_converter() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) From 376533a56411d3826df2a5b3aabc5471016496bf Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 28 Aug 2023 10:51:16 +0800 Subject: [PATCH 17/33] [shardformer] zero1+pp and the corresponding tests (#4517) * pause * finish pp+zero1 * Update test_shard_vit.py --- colossalai/pipeline/schedule/one_f_one_b.py | 3 +- colossalai/zero/low_level/low_level_optim.py | 9 ++- .../test_model/test_shard_bert.py | 9 +++ .../test_model/test_shard_bloom.py | 9 +++ .../test_model/test_shard_gpt2.py | 9 +++ .../test_model/test_shard_llama.py | 9 +++ .../test_model/test_shard_opt.py | 9 +++ .../test_model/test_shard_t5.py | 9 +++ .../test_model/test_shard_vit.py | 11 ++- .../test_model/test_shard_whisper.py | 67 ++++++++++--------- 10 files changed, 109 insertions(+), 35 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index f5e4929aa..0058873c2 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -128,11 +128,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ micro_batch = self.load_micro_batch() - # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: accum_loss.add_(loss.detach()) @@ -158,7 +158,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # Retain the grad on the input_obj. tree_map(retain_grad, input_obj) - # Backward pass. if output_obj_grad is None: optimizer.backward(output_obj) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 64d6a5395..a1e85e5b9 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -316,7 +316,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def backward(self, loss, retain_graph=False): assert not(self._partition_grads and not self.require_grad_sync), \ "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" - if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) @@ -333,6 +332,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self.zero_grad() + def backward_by_grad(self, tensor, grad): + # in lower stage which grad is transfered by higher stage + # we need to pass the optim state down. + if self.mixed_precision_mixin is not None: + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero. If set_to_none = True, gradient @@ -358,7 +364,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 76f8c0541..a15645a7f 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bert_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 0e236fd47..590eff642 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bloom_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 85d66e493..13458fc54 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_gpt2_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 485d2685e..8dc6376bf 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_llama_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 71483b752..939b2d555 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_opt_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index a853f024d..cd3d3d673 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_t5_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 0b092966c..d40058bb7 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'ViTModel': check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwrap model @@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() +#TODO: num_microbatch size = 2 inf loss @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, @@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_vit_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6eaed7d37..356ed6405 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() - +#TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 -#TODO fix WhisperForConditionalGeneration enable jit fused operator -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', -}]) +@parameterize( + 'test_config', + [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + # whisper is not supported fp16 for now. + ]) def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From c554b7f559b592c4d358db677c87658b11a6341c Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:16:40 +0800 Subject: [PATCH 18/33] =?UTF-8?q?[shardformer/fix=20overlap=20bug]=20fix?= =?UTF-8?q?=20overlap=20bug,=20add=20overlap=20as=20an=20option=20in=20sha?= =?UTF-8?q?rdco=E2=80=A6=20(#4516)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom --- colossalai/shardformer/layer/_operation.py | 53 ++++++++----------- colossalai/shardformer/layer/linear.py | 2 +- colossalai/shardformer/policies/bert.py | 21 ++++++-- colossalai/shardformer/policies/bloom.py | 11 +++- colossalai/shardformer/policies/chatglm2.py | 4 +- colossalai/shardformer/shard/shard_config.py | 9 ++++ .../test_layer/test_linear_1d.py | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f1f48273c..55d9413b9 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -211,43 +211,36 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): handle.wait() else: - # create new stream for calculate the gradient - calculate_stream = torch.cuda.Stream() - - # do all gather in default stream input_ = input_.contiguous() world_size = dist.get_world_size(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) - - # calculate gradient in calculate_stream - with torch.cuda.stream(calculate_stream): - # calculate - grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - # prepare data - input_list = [ - item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() - - torch.cuda.current_stream().wait_stream(calculate_stream) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished gather_handle.wait() + # do reduce-scatter in async way reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - with torch.cuda.stream(calculate_stream): - input_parallel = torch.cat(tensor_list, dim=dim).contiguous() - if len(input_parallel.shape) > 2: - input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) - print(grad_output.shape, input_parallel.shape) - grad_weight = grad_output.t().matmul(input_parallel) - - torch.cuda.current_stream().wait_stream(calculate_stream) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = grad_output.t().matmul(input_parallel) + # wait until reduce-scatter finished reducescatter_handle.wait() return output, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 81c3f973f..111d51b3f 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -75,7 +75,7 @@ class Linear1D_Col(ParallelModule): gather_output: bool = False, seq_parallel: bool = False, seq_parallel_dim: int = 1, - overlap: bool = False, + overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 19dd95fd6..a141b7bd8 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -56,6 +56,7 @@ class BertPolicy(Policy): policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -71,17 +72,26 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -99,7 +109,10 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="output.dense", diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 21db13f6e..7c418d02b 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -45,6 +45,7 @@ class BloomPolicy(Policy): policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -55,7 +56,10 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={'seq_parallel': use_sequence_parallel}), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, @@ -67,7 +71,10 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - kwargs={'seq_parallel': use_sequence_parallel}), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index b0d684a67..5bcbc2acc 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -50,6 +50,7 @@ class ChatGLMPolicy(Policy): policy = {} use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ @@ -81,7 +82,8 @@ class ChatGLMPolicy(Policy): target_module=col_nn.Linear1D_Col, kwargs={ 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0 + 'seq_parallel_dim': 0, + 'overlap': overlap }), SubModuleReplacementDescription(suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 900f8475c..c5c3d185e 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -20,6 +20,8 @@ class ShardConfig: enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -29,6 +31,7 @@ class ShardConfig: enable_flash_attention: bool = False enable_jit_fused: bool = False enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False # pipeline_parallel_size: int # data_parallel_size: int @@ -41,6 +44,11 @@ class ShardConfig: return self._tensor_parallel_size def __post_init__(self): + if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: + raise ValueError( + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True") + if not self.enable_sequence_parallelism and self.enable_sequence_overlap: + raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") if not self.enable_tensor_parallelism: self._tensor_parallel_size = 1 else: @@ -59,3 +67,4 @@ class ShardConfig: self.enable_flash_attention = True self.enable_jit_fused = True self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 3ad8f14b9..e6d86d533 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -168,7 +168,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool @parameterize('lazy_init', [False, True]) @parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [False, True]) +@parameterize('overlap', [True]) def run_dist_linear_test(lazy_init, seq_parallel, overlap): check_linear_1d_col(lazy_init, seq_parallel, overlap) check_linear_1d_row(lazy_init, seq_parallel) From 0387a47e63520bf112f80d094b64e1ae5890d525 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 11:25:05 +0800 Subject: [PATCH 19/33] [shardformer] fix emerged bugs after updating transformers (#4526) --- colossalai/pipeline/schedule/_utils.py | 5 ++++- tests/test_shardformer/test_model/_utils.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 3ed923927..5cd934b76 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any: merged_data = [] for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - merged_data.append(torch.cat(elem_batch, dim=0)) + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + merged_data.append(None) + else: + merged_data.append(torch.cat(elem_batch, dim=0)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 811471bec..803afc48a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor, sharded_hidden_state = sharded_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) + pipeline_output = sharded_output['outputs'] + if isinstance(pipeline_output, List): + sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) + else: + sharded_hidden_state = pipeline_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" From e241b74f24ac4efe4712bcefedfd7f14f3dd7b37 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:30:50 +0800 Subject: [PATCH 20/33] [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code --- colossalai/shardformer/layer/_operation.py | 87 ++++++++++++----- .../shardformer/layer/qkv_fused_linear.py | 4 +- .../shardformer/policies/base_policy.py | 19 ---- colossalai/shardformer/policies/gpt2.py | 94 ++++++++++--------- .../test_gpt2_qkv_fused_linear_1d.py | 10 +- 5 files changed, 120 insertions(+), 94 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 55d9413b9..f45ccc64b 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -291,12 +291,13 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter ctx.dim = dim + ctx.overlap = overlap input_parallel = _gather(input_, dim, process_group) @@ -312,37 +313,70 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group + overlap = ctx.overlap - # TODO: overlap SP input with gradient computation - input_parallel = _gather(input_, dim, process_group) + if not overlap: + input_parallel = _gather(input_, dim, process_group) - total_input = input_parallel - grad_input = grad_output.matmul(weight.T) - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - if len(grad_output.shape) > 2: - grad_output = grad_output.view(-1, grad_output.shape[-1]) - total_input = total_input.view(-1, total_input.shape[-1]) + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) - # TODO: overlap SP input with gradient computation - if ctx.async_grad_reduce_scatter: - # Asynchronous reduce-scatter + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() - handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # reduce-scatter scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() - grad_weight = total_input.t().matmul(grad_output) - grad_bias = grad_output.sum(dim=0) if use_bias else None + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() - if ctx.async_grad_reduce_scatter: - handle.wait() - - return output, grad_weight, grad_bias, None, None, None + return output, grad_weight, grad_bias, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -510,9 +544,10 @@ def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim) + async_grad_reduce_scatter, dim, overlap) def gather_forward_split_backward(input_, dim, process_group): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index ccb2bf7ea..5ce77805f 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -177,6 +177,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): async_communication: bool = False, gather_output: bool = False, seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -190,6 +191,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.out_features = out_features self.gather_output = gather_output self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -308,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): if self.seq_parallel: input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1) + self.process_group, True, 1, self.overlap) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 7022a1cfd..961c6a525 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -226,22 +226,3 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] - - def append_seq_parallel_to_policy( - self, - suffix_list: List[str], - module_policy_description: ModulePolicyDescription, - ): - r""" - Append the sequence parallel policy to the policy for the given key. - - Args: - suffix_list (List[str]): the suffix list of the module to be parallelized - policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated - """ - - for sub_description in module_policy_description.sub_module_replacement: - if (sub_description.suffix in suffix_list): - if sub_description.kwargs is None: - sub_description.kwargs = {} - sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index acae26309..5093fd469 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -37,7 +37,8 @@ class GPT2Policy(Policy): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -50,47 +51,54 @@ class GPT2Policy(Policy): ), ]) - policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GPT2Block] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -126,8 +134,6 @@ class GPT2Policy(Policy): if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} - suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] - self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) return policy diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index ae6a1dc90..4c0f884a7 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): process_group=None, gather_output=True, seq_parallel=seq_parallel, - n_fused=3) + n_fused=3, + overlap=overlap) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize('lazy_init', [False, True]) @parameterize('seq_parallel', [False, True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel) +@parameterize('overlap', [True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) check_linear_conv_1d_row(lazy_init, seq_parallel) From d367b8878589449cd5410ac8c4da756de6313aad Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 30 Aug 2023 14:50:34 +0800 Subject: [PATCH 21/33] [shardformer] fix opt test hanging (#4521) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix --- colossalai/shardformer/policies/opt.py | 26 +++---- colossalai/shardformer/policies/t5.py | 25 ++++-- colossalai/shardformer/policies/whisper.py | 18 ++++- tests/test_shardformer/test_model/_utils.py | 52 +++++++++++++ .../test_model/test_shard_bert.py | 62 ++++++++++----- .../test_model/test_shard_bloom.py | 61 ++++++++++----- .../test_model/test_shard_chatglm2.py | 78 +++++++++++-------- .../test_model/test_shard_gpt2.py | 63 ++++++++++----- .../test_model/test_shard_llama.py | 75 ++++++++++-------- .../test_model/test_shard_opt.py | 76 ++++++++++-------- .../test_model/test_shard_t5.py | 52 ++++++++----- .../test_model/test_shard_vit.py | 77 ++++++++++-------- .../test_model/test_shard_whisper.py | 56 +++++++++---- 13 files changed, 472 insertions(+), 249 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index be9d1c58b..abe491bfa 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -103,21 +103,21 @@ class OPTPolicy(Policy): target_key=OPTDecoderLayer) # use flash attention - # if self.shard_config.enable_flash_attention: - # self.append_or_create_method_replacement(description={ - # 'forward': get_opt_flash_attention_forward(), - # }, - # policy=policy, - # target_key=OPTAttention) + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement(description={ + 'forward': get_opt_flash_attention_forward(), + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator - # if self.shard_config.enable_jit_fused: - # self.append_or_create_method_replacement(description={ - # 'forward': get_jit_fused_opt_decoder_layer_forward(), - # 'dropout_add': get_jit_fused_dropout_add_func(), - # }, - # policy=policy, - # target_key=OPTDecoderLayer) + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_opt_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 192a1b847..92cbd3f72 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -184,24 +184,33 @@ class T5BasePolicy(Policy): # use flash attention if self.shard_config.enable_flash_attention: - policy[T5Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_t5_flash_attention_forward(), - }) + }, + policy=policy, + target_key=T5Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[T5LayerFF] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_T5_layer_ff_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=T5LayerFF) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_self_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=T5LayerSelfAttention) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_cross_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=T5LayerCrossAttention) + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index bffb624d0..5d496f08e 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -56,9 +56,6 @@ class WhisperPolicy(Policy): self.shard_config.enable_sequence_parallelism = False warnings.warn( "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") - if self.shard_config.enable_jit_fused: - self.shard_config.enable_jit_fused = False - warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ @@ -212,6 +209,21 @@ class WhisperPolicy(Policy): policy=policy, target_key=WhisperAttention) + # use jit fused operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_whisper_decoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperDecoderLayer) + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_whisper_encoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer) + return policy def add_lm_head_policy(self, base_policy): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 803afc48a..72bb2b025 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -237,6 +237,43 @@ def check_weight(org_model: Module, f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" +def get_grad_tensors_for_check(org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None): + + grad_to_check = {} + for suffix in layer_suffix: + org_grad = getattr_(org_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[:org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + + grad_to_check[suffix] = { + "org_grad": org_grad.float(), + "shard_grad": shard_grad.float(), + "rtol": rtol, + "atol": atol + } + + return grad_to_check + + +# used by sam/blip2 def check_grad(org_model: Module, sharded_model: Module, layer_suffix: List[str], @@ -275,3 +312,18 @@ def unwrap_model(module: Module, if module.__class__.__name__ == base_model_class_name: return module return getattr(module, base_model_attribute_name, None) + + +def check_all_grad_tensors(check_tensors): + """ + "org_grad": tensor to be compared from the original model + "shard_grad": tensor to be compared from the sharded model + """ + for suffix, check_info in check_tensors.items(): + org_grad = check_info["org_grad"] + shard_grad = check_info["shard_grad"] + rtol = check_info["rtol"] + atol = check_info["atol"] + assert torch.allclose( + org_grad, shard_grad, atol=atol, rtol=rtol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index a15645a7f..61881a1f9 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -33,8 +34,46 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, output_transform_fn, criterion, booster) + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, 'BertModel', 'bert') + sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') + + col_layer_for_check = ['encoder.layer[0].output.dense'] + row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -46,23 +85,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - bert = unwrap_model(org_model, 'BertModel', 'bert') - sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') - - col_layer_for_check = ['encoder.layer[0].output.dense'] - row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] - - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 else: @@ -70,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager is None or stage_manager.is_first_stage(): check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 590eff642..f7ab94bc9 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + bloom = unwrap_model(org_model, 'BloomModel', 'transformer') + sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') + + row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] + col_layer_for_check = ['h[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -47,24 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - bloom = unwrap_model(org_model, 'BloomModel', 'transformer') - sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') - - # check grad - row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] - col_layer_for_check = ['h[0].self_attention.dense'] - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-5 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -72,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index a8957d8d3..c5a3e68f7 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,6 +37,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') + shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') + + row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] + col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + col_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -48,39 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') - shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') - - # check grad - row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] - col_layer_for_check = ['encoder.layers[0].self_attention.dense'] - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(chatglm_model, - shard_chatglm_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - - check_grad(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -95,6 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 13458fc54..44914721c 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') + sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') + + col_layer_for_check = ['h[0].mlp.c_fc'] + row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + col_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -48,25 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') - sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') - - col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] - - # check grad - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -74,6 +94,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8dc6376bf..c9d5d3d08 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -12,10 +12,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -41,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + llama_model = unwrap_model(org_model, 'LlamaModel', 'model') + shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') + + row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] + col_layer_for_check = ['layers[0].self_attn.o_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -53,37 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - llama_model = unwrap_model(org_model, 'LlamaModel', 'model') - shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') - # check grad - row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - col_layer_for_check = ['layers[0].self_attn.o_proj'] - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-4 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(llama_model, - shard_llama_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -98,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 939b2d555..8c0432b37 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -11,10 +11,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -40,6 +41,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + opt_model = unwrap_model(org_model, 'OPTModel', 'model') + shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') + + row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' + col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 4e-2, 4e-2 + row_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -51,38 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - opt_model = unwrap_model(org_model, 'OPTModel', 'model') - shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') - - # check grad - row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' - col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-3 - else: - atol, rtol = 3e-2, 3e-2 - check_grad(opt_model, - shard_opt_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-3, 1e-3 @@ -97,6 +104,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index cd3d3d673..29367031e 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -37,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) + + row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + row_layer_grads = get_grad_tensors_for_check(t5, + sharded_t5, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -49,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - t5 = unwrap_model(org_model) - sharded_t5 = unwrap_model(sharded_model) - - row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] - - # check grad + # check weights if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 + atol, rtol = 5e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index d40058bb7..2980c6eea 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, unwrap_model, ) @@ -36,6 +37,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + vit_model = unwrap_model(org_model, 'ViTModel', 'vit') + shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') + + # check grad + row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] + col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -47,38 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - vit_model = unwrap_model(org_model, 'ViTModel', 'vit') - shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') - - # check grad - row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] - col_layer_for_check = ['encoder.layer[0].attention.output.dense'] - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(vit_model, - shard_vit_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -93,6 +101,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 356ed6405..a55753018 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -15,10 +15,11 @@ from colossalai.testing import ( from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, ) @@ -41,18 +42,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 2e-4, 2e-4 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'WhisperModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model @@ -75,19 +64,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, #'decoder.layers[0].self_attn.out_proj' ] - # check weights and gradients + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} if test_config['precision'] == 'fp32': atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) - check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) + row_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1) + col_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) - # check weights after optimizer.step() + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if test_config['precision'] == 'fp32': atol, rtol = 5e-4, 5e-4 else: @@ -110,8 +128,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=0, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() + #TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 @parameterize( From ec18fc7340f99693f2436e91e1dea99342f476d5 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 30 Aug 2023 21:29:18 +0800 Subject: [PATCH 22/33] [shardformer] support pp+tp+zero1 tests (#4531) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 --- colossalai/zero/low_level/low_level_optim.py | 15 +++++++++++++-- .../test_model/test_shard_bert.py | 9 +++++++++ .../test_model/test_shard_bloom.py | 10 ++++++++++ .../test_model/test_shard_chatglm2.py | 10 ++++++++++ .../test_model/test_shard_gpt2.py | 10 ++++++++++ .../test_model/test_shard_llama.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_opt.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_t5.py | 10 ++++++++++ .../test_shardformer/test_model/test_shard_vit.py | 9 +++++++++ .../test_model/test_shard_whisper.py | 11 ++++++++++- 10 files changed, 101 insertions(+), 3 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a1e85e5b9..85ac9eb48 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -333,12 +333,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): self.zero_grad() def backward_by_grad(self, tensor, grad): - # in lower stage which grad is transfered by higher stage - # we need to pass the optim state down. + assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) torch.autograd.backward(tensor, grad) + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + + self.zero_grad() + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero. If set_to_none = True, gradient diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 61881a1f9..0855e2248 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -163,6 +163,15 @@ def run_bert_test(test_config): 'enable_all_optimization': False, 'use_lazy_init': False, 'precision': 'fp32', + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, 'initial_scale': 1, }, ]) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index f7ab94bc9..c9ee690c8 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -165,6 +165,16 @@ def run_bloom_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_bloom_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index c5a3e68f7..05ca05dea 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -165,6 +165,16 @@ def run_chatglm_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_chatglm_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 44914721c..563084ed0 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -183,6 +183,16 @@ def run_gpt2_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) @clear_cache_before_run() def run_gpt2_3d_test(test_config): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c9d5d3d08..a60150e3c 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -185,6 +185,16 @@ def run_llama_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_llama_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8c0432b37..25b1eefc6 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -174,6 +174,16 @@ def run_opt_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_opt_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 29367031e..768cae0a6 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -170,6 +170,16 @@ def run_t5_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, ]) def run_t5_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 2980c6eea..15db63bfd 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -176,6 +176,15 @@ def run_vit_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, ]) def run_vit_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index a55753018..d0c04c98f 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if test_config['precision'] == 'fp32': - atol, rtol = 5e-4, 5e-4 + atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): @@ -195,6 +195,15 @@ def run_whisper_test(test_config): 'precision': 'fp32', 'initial_scale': 1, }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, ]) def run_whisper_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') From 2c787d7f47f7aa55c27877a66f79e4226d16b92a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 31 Aug 2023 09:57:18 +0800 Subject: [PATCH 23/33] [shardformer] fix submodule replacement bug when enabling pp (#4544) --- colossalai/shardformer/shard/sharder.py | 25 ++++++++++--------- ...st_hybrid_parallel_plugin_checkpoint_io.py | 2 ++ .../test_model/test_shard_chatglm2.py | 2 ++ .../test_model/test_shard_gpt2.py | 2 ++ .../test_model/test_shard_opt.py | 2 ++ 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 0ed745a1f..9ed384266 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -92,22 +92,21 @@ class ModelSharder(object): param_replacement (List[Callable]): The function list to get parameter shard information in policy method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ - # released layers are not shardable - can_replace_param_or_layer = include is None or module in include if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): if attr_replacement is not None: self._replace_attr(module, attr_replacement) - if param_replacement is not None and can_replace_param_or_layer: + if param_replacement is not None and (include is None or module in include): self._replace_param(module, param_replacement) if method_replacement is not None: self._replace_method(module, method_replacement) - if sub_module_replacement is not None and can_replace_param_or_layer: - self._replace_sub_module(module, sub_module_replacement) + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement, include) for name, child in module.named_children(): self._recursive_replace_layer(child, @@ -154,18 +153,17 @@ class ModelSharder(object): bound_method = MethodType(new_method, module) setattr(module, method_name, bound_method) - def _replace_sub_module( - self, - org_layer: nn.Module, - sub_module_replacement: List[SubModuleReplacementDescription], - ) -> None: + def _replace_sub_module(self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: org_layer (torch.nn.Module): The origin layer object to shard sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list - + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ for description in sub_module_replacement: suffix = description.suffix @@ -174,9 +172,12 @@ class ModelSharder(object): assert target_module is not None, 'target_module should not be None' - # TODO: support different parallel mode native_sub_module = getattr_(org_layer, suffix, ignore=True) + # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. + if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): + continue + assert not isinstance(native_sub_module, target_module), \ f"The module with suffix {suffix} has been replaced, please check the policy" diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index ea0922ef5..67d73c31f 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -7,6 +7,7 @@ from utils import shared_tempdir import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( check_state_dict_equal, @@ -100,6 +101,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + Randomizer.reset_index() clear_layout_converter() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 05ca05dea..48f651c72 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -4,6 +4,7 @@ from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -105,6 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 563084ed0..115a1bd79 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -4,6 +4,7 @@ from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -97,6 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 25b1eefc6..3e74859ad 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,6 +6,7 @@ from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -107,6 +108,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check grads check_all_grad_tensors(grads_to_check) + Randomizer.reset_index() torch.cuda.empty_cache() From c9625dbb6364c10f21828b30bc58e8fbcf22a900 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 31 Aug 2023 14:50:47 +0800 Subject: [PATCH 24/33] [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540) * implement sharded optimizer saving * add more param info * finish implementation of sharded optimizer saving * fix bugs in optimizer sharded saving * add pp+zero test * param group loading * greedy loading of optimizer * fix bug when loading * implement optimizer sharded saving * add optimizer test & arrange checkpointIO utils * fix gemini sharding state_dict * add verbose option * add loading of master params * fix typehint * fix master/working mapping in fp16 amp --- .../booster/plugin/hybrid_parallel_plugin.py | 53 +- .../hybrid_parallel_checkpoint_io.py | 445 +++++++++++++-- colossalai/checkpoint_io/utils.py | 523 +++++++++--------- colossalai/zero/gemini/gemini_ddp.py | 6 +- colossalai/zero/gemini/gemini_optimizer.py | 44 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 110 ++-- 6 files changed, 812 insertions(+), 369 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c49b3e182..277843b66 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,7 +1,7 @@ import random from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper): return module +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) new_param_groups = [] @@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): optim: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, precision: str = 'fp16', initial_scale: float = 2**16, min_scale: float = 1, @@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): optimizer: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2., @@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase): dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, @@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase): optimizer = HybridParallelAMPOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, precision=self.precision, max_norm=self.max_norm, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, - use_pipeline=self.enable_pipeline_parallelism) + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline(self, @@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) + self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 56a89bff7..c128858b1 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ import logging import os from pathlib import Path from shutil import rmtree -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union import torch import torch.distributed as dist @@ -13,29 +13,23 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from colossalai.cluster import ProcessGroupMesh -from colossalai.tensor.d_tensor import ( - is_customized_distributed_tensor, - is_distributed_tensor, - to_global, - to_global_for_customized_distributed_tensor, -) +from colossalai.interface import OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, - calculate_tensor_size, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, load_shard_state_dict, load_state_dict_into_model, + load_states_into_optimizer, save_param_groups, - save_state_dict, save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, ) try: @@ -52,9 +46,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): dp_group (ProcessGroup): Process group along data parallel dimension. pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. """ - def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + def __init__(self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -65,6 +66,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) + self.use_zero = (zero_stage > 0) + self.verbose = verbose + self.working_to_master_map = None + self.master_to_working_map = None @staticmethod def _model_sharder(model: nn.Module, @@ -81,7 +86,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): continue # Gather tensor pieces when using tensor parallel. param_ = gather_distributed_param(param, keep_vars=False) - block, block_size = state_dict_sharder.append(prefix + name, param_) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -89,7 +94,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): for name, buf in model.named_buffers(): if buf is not None and name not in model._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = state_dict_sharder.append(prefix + name, buffer) + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -98,7 +103,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): if getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = model.get_extra_state() - block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -106,10 +111,44 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + def _optimizer_sharder(optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. - # TODO (Baizhou): Implement sharding feature of optimizer. - pass + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + + for param, state in optimizer.optim.state.items(): + + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info['param2id'][id(working_param)] + original_shape = param_info['param2shape'][id(working_param)] + state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def save_sharded_model(self, model: nn.Module, @@ -148,7 +187,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): return # Then collect the sharded parameters & buffers along tp_group. - # Only devices with tp_size == 0 are responsible for model saving. + # Only devices with tp_rank == 0 are responsible for model saving. state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) @@ -165,9 +204,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") else: # When pipeline is used, each stage produces its own shard files and index files. @@ -212,9 +252,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -222,7 +263,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): Args: model (nn.Module): The model to be loaded. - index_file_path (str): Path to the index file of checkpointing folder. + checkpoint_index_file (str): Path to the index file of checkpointing folder. strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ @@ -263,7 +304,6 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): missing_keys=missing_keys, strict=strict, load_sub_module=True) - del state_dict loaded_file.add(filename) # Load parameters. @@ -271,8 +311,11 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): _load(name) # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: + if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. @@ -281,16 +324,236 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: _load(extra_state_key) + # Update master params if mixed-precision training is enabled. + with torch.no_grad(): + if self.working_to_master_map is not None: + for param in model.parameters(): + if (param is None) or (id(param) not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[id(param)] + if self.use_zero: + # master_param is sharded under Zero setting + padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size + if padding_size > 0: + padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padded_param = param.data.view(-1) + sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] + master_param.data.copy_(sharded_param.data) + else: + master_param.data.copy_(param.data) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def save_sharded_optimizer(self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024): - pass + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): - pass + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + master_to_working_map=self.master_to_working_map, + size_per_shard=size_per_shard) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose: + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + + def _get_param_id_from_optimizer_param(param: torch.Tensor, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info['param2id'][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({'param_groups': updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg['params']: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if self.master_to_working_map is not None: + working_param = self.master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info['param2shape'][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state(state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose: + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection @@ -314,3 +577,121 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): """ if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + + def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + """ + Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. + This mapping can only be created when mixied precision is used. + The created mappings should be mappings from integer parameter addresses to parameter objects. + + Args: + working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. + """ + self.working_to_master_map = dict() + for k, v in working_to_master_map.items(): + if isinstance(k, torch.Tensor): + self.working_to_master_map[id(k)] = v + elif isinstance(k, int): + self.working_to_master_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + self.master_to_working_map = dict() + for k, v in master_to_working_map.items(): + if isinstance(k, torch.Tensor): + self.master_to_working_map[id(k)] = v + elif isinstance(k, int): + self.master_to_working_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + @staticmethod + def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, + dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, + inplace: bool) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().cpu() + + return state_ + + def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, + original_shape: torch.Size, device: torch.device, + inplace: bool) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d04159c54..0025d07df 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import copy import os import re from collections import abc as container_abcs @@ -8,7 +9,9 @@ from pathlib import Path from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -93,24 +96,31 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False -def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: """ - Gather the complete parameter for saving if passed in param is distributed. + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. Args: - param (torch.Tensor): A model parameter, might be d_tensor. - keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. Returns: - torch.Tensor: the complete parameter + Optional[int]: The dimension along which parameter is partitioned. """ - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - return to_global(param_) - elif is_customized_distributed_tensor(param_): - return to_global_for_customized_distributed_tensor(param_) - else: - return param_ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim # ====================================== @@ -136,7 +146,8 @@ class StateDictSharder: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -153,6 +164,64 @@ class StateDictSharder: self.current_block_size += tensor_size return ret_block, ret_block_size + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + ret_block = None + ret_block_size = 0 + + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: + """ + Gather the complete parameter for saving if passed in param is distributed under tp setting. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, @@ -198,28 +267,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for key, weight in state_dict.items(): - ret_block = None - ret_block_size = 0 if not is_distributed_tensor(weight): - weight_size = calculate_tensor_size(weight) + block, block_size = state_dict_sharder.append_param(key, weight) - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - current_block[key] = weight - current_block_size += weight_size + if block != None: + yield block, block_size - if ret_block != None: - yield ret_block, ret_block_size - - yield current_block, current_block_size + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: @@ -230,212 +288,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. states = state_dict['state'] - - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size - ret_block = None - ret_block_size = 0 - - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size - - yield current_block, current_block_size - - -def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): - """ - load shard state dict into model - """ - if use_safetensors and not checkpoint_file.suffix == ".safetensors": - raise Exception("load the model using `safetensors`, but no file endwith .safetensors") - if use_safetensors: - from safetensors.torch import load_file as safe_load_file - from safetensors.torch import safe_open - with safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata["format"] != "pt": - raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") - return safe_load_file(checkpoint_file) - else: - return torch.load(checkpoint_file) - - -def load_state_dict_into_model(model: nn.Module, - state_dict: torch.Tensor, - missing_keys: List, - strict: bool = False, - load_sub_module: bool = True): - r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - """ - if not isinstance(state_dict, Mapping): - raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) - - unexpected_keys: List[str] = [] - sub_missing_keys: List[str] = [] - error_msgs: List[str] = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = OrderedDict(state_dict) - if metadata is not None: - state_dict._metadata = metadata - - def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) - # Parameters of module and children will start with prefix. We can exit early if there are none in this - # state_dict - if len([key for key in state_dict if key.startswith(prefix)]) > 0: - module._load_from_state_dict(*args) - if load_sub_module: - for name, child in module._modules.items(): - if child is not None: - load(child, state_dict, prefix + name + ".") - - load(model, state_dict, "", load_sub_module) - del load - - missing_keys = missing_keys.append(sub_missing_keys) - - if strict: - if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - - -def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: - """ - Load information of param_groups into an initialized optimizer. - """ - - # Load list of param_groups from given file path. - # The params in saved_groups are in the form of integer indices. - saved_groups = torch.load(param_group_path) - if not isinstance(saved_groups, List): - raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') - - # The params in param_groups are in the form of pytorch tensors. - # For more details, please view source code of Optimizer class in pytorch. - param_groups = optimizer.param_groups - - # Check the compatibility of saved_groups and param_groups. - if len(param_groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of original parameter groups") - param_lens = (len(g['params']) for g in param_groups) - saved_lens = (len(g['params']) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") - - # Creating mapping from id to parameters. - id_map = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in param_groups))) - } - - # Update parameter groups, setting their 'params' value. - def update_group(group, new_group): - new_group['params'] = group['params'] - return new_group - - updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] - - optimizer.__dict__.update({'param_groups': updated_groups}) - return id_map - - -def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): - r"""Copies states from `state_dict` into an Optimizer object. - - Args: - optimizer(Optimizer): An initialized Optimizer object to be loaded - state_dict(dict): a mapping from tensor index (an integer) - to its states to be loaded (a mapping from state name to a tensor). - id_map(dict): a mapping from tensor index (an integer) - to its corresponding parameter (a tensor) whose states will be updated. - """ - - def cast(param, value, key=None): - r"""Make a deep copy of value, casting all tensors to device of param.""" - if isinstance(value, torch.Tensor): - # Floating-point types are a bit special here. They are the only ones - # that are assumed to always match the type of params. - # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 - if (key != "step"): - if param.is_floating_point(): - value = value.to(param.dtype) - value = value.to(param.device) - return value - elif isinstance(value, dict): - return {k: cast(param, v, key=k) for k, v in value.items()} - elif isinstance(value, container_abcs.Iterable): - return type(value)(cast(param, v) for v in value) - else: - return value - - # Copy state assigned to params (and cast tensors to appropriate types). - # State that is not assigned to params is copied as is (needed for - # backward compatibility). - new_states = defaultdict(dict) - for k, v in state_dict.items(): - if k in id_map: - param = id_map[k] - new_states[param] = cast(param, v) - else: - new_states[k] = v - - optimizer.state.update(new_states) - - -def sharded_optimizer_loading_epilogue(optimizer: Optimizer): - r"""Do the cleaning up work after state_dict has been loaded into optimizer - - Args: - optimizer(Optimizer): An optimizer object whose state has just been loaded. - """ - - # Do the cleaning up as in src code of Pytorch. - optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. - optimizer.defaults.setdefault('differentiable', False) + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size # ====================================== @@ -565,38 +426,180 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo return f'{param_name}.{index}.{suffix}' -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() - - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) - - # ======================================== # Helper functions for loading state dict # ======================================== +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): + """ + load shard state dict into model + """ + if use_safetensors and not checkpoint_file.suffix == ".safetensors": + raise Exception("load the model using `safetensors`, but no file endwith .safetensors") + if use_safetensors: + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import safe_open + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") + return safe_load_file(checkpoint_file) + else: + return torch.load(checkpoint_file) + + +def load_state_dict_into_model(model: nn.Module, + state_dict: torch.Tensor, + missing_keys: List, + strict: bool = False, + load_sub_module: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + unexpected_keys: List[str] = [] + sub_missing_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + state_dict._metadata = metadata + + def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + if load_sub_module: + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model, state_dict, "", load_sub_module) + del load + + missing_keys = missing_keys.append(sub_missing_keys) + + if strict: + if len(unexpected_keys) > 0: + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: + """ + Load information of param_groups into an initialized optimizer. + """ + + # Load list of param_groups from given file path. + # The params in saved_groups are in the form of integer indices. + saved_groups = torch.load(param_group_path) + if not isinstance(saved_groups, List): + raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') + + # The params in param_groups are in the form of pytorch tensors. + # For more details, please view source code of Optimizer class in pytorch. + param_groups = optimizer.param_groups + + # Check the compatibility of saved_groups and param_groups. + if len(param_groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of original parameter groups") + param_lens = (len(g['params']) for g in param_groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Creating mapping from id to parameters. + id_map = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in param_groups))) + } + + # Update parameter groups, setting their 'params' value. + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + + updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] + + optimizer.__dict__.update({'param_groups': updated_groups}) + return id_map + + +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): + r"""Copies states from `state_dict` into an Optimizer object. + + Args: + optimizer(Optimizer): An initialized Optimizer object to be loaded + state_dict(dict): A mapping from tensor index (an integer) + to its states to be loaded (a mapping from state name to a tensor). + id_map(dict): A mapping from tensor index (an integer) + to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False. + """ + + # Ensure that the keys of state_dict are integers. + state_dict = {int(k): v for k, v in state_dict.items()} + + def cast(param, value, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + if (key != "step"): + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v, key=k) for k, v in value.items()} + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + new_states = defaultdict(dict) + for k, v in state_dict.items(): + if k in id_map: + param = id_map[k] + new_states[param] = cast(param, v) + elif not strict: + new_states[k] = v + + optimizer.state.update(new_states) + + +def sharded_optimizer_loading_epilogue(optimizer: Optimizer): + r"""Do the cleaning up work after state_dict has been loaded into optimizer + + Args: + optimizer(Optimizer): An optimizer object whose state has just been loaded. + """ + + # Do the cleaning up as in src code of Pytorch. + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. + optimizer.defaults.setdefault('differentiable', False) + + def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5aff91f03..1c19071fe 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -679,7 +679,7 @@ class ZeroDDP(ColoDDP): gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block, block_size = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -690,7 +690,7 @@ class ZeroDDP(ColoDDP): for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size # save extra states @@ -698,7 +698,7 @@ class ZeroDDP(ColoDDP): if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block, block_size = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index a2085323f..58b0f33ab 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,7 +10,7 @@ from torch.nn import Parameter from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.tensor.d_tensor import is_distributed_tensor @@ -691,49 +691,17 @@ class ZeroOptimizer(ColossalaiOptimizer): Iterator[OrderedDict]: A generator of state dict shard of optimizer states. """ - current_block = {} - current_block_size = 0 - + sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) - ret_block = None - ret_block_size = 0 + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size - - yield current_block, current_block_size + yield sharder.current_block, sharder.current_block_size class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 67d73c31f..e43908e0c 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -10,6 +10,7 @@ from colossalai.booster.plugin import HybridParallelPlugin from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( + assert_close_loose, check_state_dict_equal, clear_cache_before_run, parameterize, @@ -19,34 +20,34 @@ from colossalai.testing import ( from tests.kit.model_zoo import model_zoo +# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() @parameterize('shard', [True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { 'tp_size': 4, 'pp_size': 1, 'precision': 'fp32', }, { 'tp_size': 2, - 'pp_size': 1, - 'precision': 'fp32', + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 }, { 'tp_size': 2, 'pp_size': 1, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): @@ -61,46 +62,91 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf loss = criterion(outputs) return loss + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + model = model_fn().cuda() optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - new_model = model_fn().cuda() - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - data_iter = iter([data]) - output = booster.execute_pipeline(data_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) else: - data = {k: v.cuda() for k, v in data.items()} - output = model(**data) + output = model(**_preprocess_data(data)) loss = criterion(output) optimizer.backward(loss) optimizer.step() with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" - # optimizer_ckpt_path = f"{tempdir}/optimizer" + optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() + # Check whether the loaded model & optimizer works smoothly. + model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, + new_model.unwrap().h[0].mlp.c_fc.weight.data, + atol=5e-3, + rtol=5e-3) + + dist.barrier() Randomizer.reset_index() clear_layout_converter() From 38ccb8b1a321fa70926236a22cfd7911a993b53e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 1 Sep 2023 17:40:01 +0800 Subject: [PATCH 25/33] [shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575) * hybrid plugin support huggingface from_pretrained * add huggingface compatibility tests * add folder cleaning * fix bugs --- .github/workflows/build_on_pr.yml | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 4 +- .../hybrid_parallel_checkpoint_io.py | 19 ++- colossalai/checkpoint_io/utils.py | 81 ++++++++++- .../test_hybrid_huggingface_compatibility.py | 129 ++++++++++++++++++ 5 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 4c7e08e57..3f91dc33a 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -208,7 +208,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 277843b66..eced4fc1a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer): def init_pipeline_optimizer(optim: Optimizer, model: Module): - params = set(model.parameters()) + model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: - params = [p for p in group['params'] if p in params] + params = [p for p in group['params'] if p in model_params] new_param_groups.append({**group, 'params': params}) optim.__setstate__({'param_groups': new_param_groups}) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index c128858b1..fef5b0d16 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -26,6 +26,7 @@ from .utils import ( load_shard_state_dict, load_state_dict_into_model, load_states_into_optimizer, + save_config_file, save_param_groups, save_state_dict_shards, search_tp_partition_dim, @@ -204,6 +205,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) if self.verbose: logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " @@ -219,9 +221,9 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. - weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, @@ -229,7 +231,8 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): index_file=index_file, base_filename=weights_name, is_master=control_saving, - use_safetensors=use_safetensors) + use_safetensors=use_safetensors, + use_pp_format=True) if control_saving: assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) @@ -251,6 +254,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): final_index_file.append_weight_map(weight, weight_filename) final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) if self.verbose: logging.info(f"The model is split into checkpoint shards. " @@ -423,15 +427,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. - states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, checkpoint=checkpoint, index_file=index_file, base_filename=states_name, - is_master=control_saving) + is_master=control_saving, + use_pp_format=True) if control_saving: assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 0025d07df..0300e6265 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -9,12 +9,12 @@ from pathlib import Path from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed import ProcessGroup from torch.optim import Optimizer +from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype +from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, @@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] index_file: "CheckpointIndexFile", base_filename: str, is_master: bool, - use_safetensors: bool = False) -> int: + use_safetensors: bool = False, + use_pp_format: bool = False) -> int: ''' Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Args: @@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] checkpoint (str): The path of checkpoint directory as string. index_file (CheckpointIndexFile): The index file object to be updated. base_filename (str): Decides the prefix of filenames of shards. - is_master (bool): Whether current rank is master. - use_safetensors (bool): Whether to use safetensors to save checkpoint. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. Returns: int: the total size of shards ''' total_size = 0 + shard_filenames = [] for idx, shard_pair in enumerate(sharded_state_dict): shard, current_size = shard_pair if not is_master: @@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + shard_filenames.append(shard_file) del shard + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + return total_size @@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None: torch.save(param_groups, group_file_path) +def clean_folder(checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False): + """ + Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. + + Args: + checkpoint_path (str): Path to the checkpoint directory. + weights_name (str): Decides the prefix of filenames of weight shards. + shard_filenames (List[str]): The list of saved shard filenames which should not be removed. + is_master (bool, optional): Whether current rank is main process. Defaults to True. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + """ + if is_master: + for filename in os.listdir(checkpoint_path): + full_filename = os.path.join(checkpoint_path, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + if not use_pp_format: + reg = re.compile(r"(.*?)-\d{5}") + else: + # When this checkpoint is created by pipeline parallel process, the pattern is a little different. + reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") + if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) + and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): + os.remove(full_filename) + + +def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True): + """ + Save config.json/generation_config.json if model is a Huggingface pretrained model. + This method can only be called when a model is saved in a sharded way. + + Args: + model (nn.Module): The model whose config should be saved if it's a huggingface model. + checkpoint_path (str): Path to the checkpoint directory. + is_master (bool): Whether current rank is main process. + """ + if not isinstance(model, PreTrainedModel): + return + + model = unwrap_huggingface_model(model) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + dtype = get_parameter_dtype(model) + model.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model.config.architectures = [model.__class__.__name__] + + # Save the config + if is_master: + model.config.save_pretrained(checkpoint_path) + if model.can_generate(): + model.generation_config.save_pretrained(checkpoint_path) + + def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: """ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains @@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int): get shard file name """ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") return shard_file diff --git a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py new file mode 100644 index 000000000..df907605d --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py @@ -0,0 +1,129 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +def exam_from_pretrained(model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + test_config, + shard=True, + size_per_shard=32): + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model = model_fn() + optimizer = Adam((model.parameters()), lr=0.001) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + output = model(**_preprocess_data(data)) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path) + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@clear_cache_before_run() +@parameterize('test_config', [{ + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 +}]) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + clear_layout_converter() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_test() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) From 508ca36fe37a8d9434647d224757e06833ed6557 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 1 Sep 2023 21:45:14 +0800 Subject: [PATCH 26/33] [pipeline] 1f1b schedule receive microbatch size (#4589) --- .../booster/plugin/hybrid_parallel_plugin.py | 8 +++++- colossalai/pipeline/schedule/one_f_one_b.py | 27 +++++++++++++++---- .../test_schedule/test_oneF_oneB.py | 2 +- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index eced4fc1a..c83e51b26 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. @@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase): assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 0058873c2..11b2655a2 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -17,14 +17,26 @@ from .base import PipelineSchedule class OneForwardOneBackwardSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: + def __init__(self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None) -> None: + """1F1B pipeline schedule. + + Args: + stage_manager (PipelineStageManager): Pipeline stage manager + num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. + microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + """ super().__init__(stage_manager) + assert num_microbatches is not None or microbatch_size is not None, \ + "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches + if self.num_microbatches is not None: + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + else: + assert self.batch_size % self.microbatch_size == 0, \ + "Batch size should divided by the microbatch size" + self.num_microbatches = self.batch_size // self.microbatch_size def load_micro_batch(self) -> Any: """Load a micro batch from the current batch. diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 542116a1d..d31eafd70 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -61,7 +61,7 @@ def examine_pp(): DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(1, world_size, 1) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) for idx, (_, sub_model) in enumerate(pp_model.named_children()): if idx % (world_size) == local_rank: From 24c076879558133d66ffcb6111f9bccaa23f6017 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:52:23 +0800 Subject: [PATCH 27/33] [shardformer] Pytree fix (#4533) * pytree test * test bert * test bert * test bert * revise * add register * add register --- colossalai/pipeline/schedule/_utils.py | 62 +++++++++++++++++-- colossalai/pipeline/schedule/one_f_one_b.py | 19 ++++-- colossalai/shardformer/policies/chatglm2.py | 5 ++ tests/test_shardformer/test_model/_utils.py | 11 +--- .../test_model/test_shard_bert.py | 1 + 5 files changed, 81 insertions(+), 17 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 5cd934b76..583558551 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -1,9 +1,59 @@ -from typing import Any, List, Optional +from collections import OrderedDict +from typing import Any, List, Optional, Tuple import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import ( + SUPPORTED_NODES, + LeafSpec, + TreeSpec, + _is_leaf, + _register_pytree_node, + tree_flatten, + tree_map, + tree_unflatten, +) + + +# this register are for torch under version 1.13.1, maybe removed in the future +def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +def tree_map_hf(fn: Any, pytree: Any): + flat_args, spec = tree_flatten_hf(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +# use this flatten function to handle the ModelingOutput Class instance. +def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values an a TreeSpec that can be used + to reconstruct the pytree. + """ + if isinstance(pytree, OrderedDict): + node_type = OrderedDict + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List['TreeSpec'] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten_hf(child) + result += flat + children_specs.append(child_spec) + return result, TreeSpec(node_type, context, children_specs) + else: + result, tree_spec = tree_flatten(pytree) + return result, tree_spec def to_device(x: Any, device: Optional[torch.device] = None) -> Any: @@ -104,7 +154,7 @@ def detach(x: Any) -> Any: return x -def merge_batch(data: List[Any]) -> Any: +def merge_batch(data: List[Any], batch_size_dim=0) -> Any: """Merge micro batches into a batch. Args: @@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any: flattened_data = [] tree_spec = None for d in data: - elems, tree_spec = tree_flatten(d) + # elems should be an instance of OrderedDict + elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems) merged_data = [] + for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs merged_data.append(None) else: - merged_data.append(torch.cat(elem_batch, dim=0)) + merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 11b2655a2..ec53a6771 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,12 +6,21 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.cuda import get_current_device -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from ._utils import ( + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + retain_grad, + to_device, + tree_map_hf, +) from .base import PipelineSchedule @@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: - outputs.append(tree_map(detach, output_obj)) + outputs.append(tree_map_hf(detach, output_obj)) return loss else: return output_obj @@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_backward(input_obj_grad) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 5bcbc2acc..448988470 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy): new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + if self.pipeline_stage_manager is not None: + # the batch_size_dim is bounded to Model + bsz_dim = 1 + setattr(self.model, 'batch_size_dim', bsz_dim) + return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 72bb2b025..f77bf7495 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state - if stage_manager is None: - sharded_hidden_state = sharded_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): - pipeline_output = sharded_output['outputs'] - if isinstance(pipeline_output, List): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) - else: - sharded_hidden_state = pipeline_output.last_hidden_state + sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + else: + sharded_hidden_state = sharded_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0855e2248..c779e4170 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -179,6 +179,7 @@ def run_bert_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() From 0a94fcd3514a6f7d4f287bba614fda3fb12c8802 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 4 Sep 2023 21:46:29 +0800 Subject: [PATCH 28/33] [shardformer] update bert finetune example with HybridParallelPlugin (#4584) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * [shardformer] fix opt test hanging * fix * test * test * [shardformer] zero1+pp and the corresponding tests (#4517) * pause * finish pp+zero1 * Update test_shard_vit.py * [shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516) * fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom * [shardformer] fix emerged bugs after updating transformers (#4526) * test * fix test * fix test * remove print * add fix * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] Add overlap support for gpt2 (#4535) * add overlap support for gpt2 * remove unused code * remove unused code * [shardformer] support pp+tp+zero1 tests (#4531) * [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fix * [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] pp+tp+zero1 * [shardformer] fix submodule replacement bug when enabling pp (#4544) * [shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540) * implement sharded optimizer saving * add more param info * finish implementation of sharded optimizer saving * fix bugs in optimizer sharded saving * add pp+zero test * param group loading * greedy loading of optimizer * fix bug when loading * implement optimizer sharded saving * add optimizer test & arrange checkpointIO utils * fix gemini sharding state_dict * add verbose option * add loading of master params * fix typehint * fix master/working mapping in fp16 amp * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] add bert finetune example * [shardformer] fix epoch change * [shardformer] broadcast add pp group * rebase feature/shardformer * update pipeline * [shardformer] fix * [shardformer] fix * [shardformer] bert finetune fix * [shardformer] add all_reduce operation to loss add all_reduce operation to loss * [shardformer] make compatible with pytree. make compatible with pytree. * [shardformer] disable tp disable tp * [shardformer] add 3d plugin to ci test * [shardformer] update num_microbatches to None * [shardformer] update microbatchsize * [shardformer] update assert * update scheduler * update scheduler --------- Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Baizhou Zhang --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/pipeline/schedule/one_f_one_b.py | 3 +- examples/language/bert/finetune.py | 161 ++++++++++++++---- examples/language/bert/test_ci.sh | 2 +- 4 files changed, 133 insertions(+), 35 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c83e51b26..8ad9b7956 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -325,7 +325,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.schedule = None assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' + assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ec53a6771..5db1c7f30 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -46,6 +46,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None + self._use_microbatch_size = num_microbatches is None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -60,7 +61,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - if self.num_microbatches is not None: + if not self._use_microbatch_size: assert self.batch_size % self.num_microbatches == 0, \ "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index b209ffde8..b9a3d5753 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -1,12 +1,14 @@ import argparse -from typing import List, Union +from contextlib import nullcontext +from typing import Callable, List, Union import evaluate import torch import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Optimizer +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ( @@ -18,8 +20,9 @@ from transformers import ( import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -32,14 +35,26 @@ LEARNING_RATE = 2.4e-5 WEIGHT_DECAY = 0.01 WARMUP_FRACTION = 0.1 +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + def move_to_cuda(batch): return {k: v.cuda() for k, v in batch.items()} @torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, - eval_splits: List[str], coordinator: DistCoordinator): +def evaluate_model( + model: nn.Module, + optimizer, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -47,23 +62,66 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - labels = batch["labels"] + batch_size = batch["input_ids"].shape[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + #TODO pass dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) - metric.add_batch(predictions=preds, references=labels) + if booster.plugin.stage_manager.is_last_stage(): + val_loss = outputs["loss"] + + logits = outputs["outputs"]["logits"] + + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + dist.broadcast(preds, src=current_rank, group=pp_group) + dist.broadcast(val_loss, src=current_rank, group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + val_loss = torch.empty((1,), device=get_current_device()) + preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device()) + + dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group) + dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group) + + accum_loss.add_(val_loss) + metric.add_batch(predictions=preds, references=labels) + + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) - if coordinator.is_master(): + if coordinator.is_master() and results is not None: results['loss'] = accum_loss.item() / coordinator.world_size + return results if isinstance(test_dataloader, DataLoader): @@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + is_pp_last_stage = hasattr( + booster.plugin, + "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage() + with tqdm(train_dataloader, + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: for batch in pbar: # Forward pass batch = move_to_cuda(batch) - outputs = model(**batch) - loss = outputs[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + #TODO pass train_dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if booster.plugin.stage_manager.is_last_stage(): + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + else: + outputs = model(**batch) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss.item()}) - # Backward and optimize - booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() lr_scheduler.step() - # Print log info - pbar.set_postfix({'loss': loss.item()}) - def main(): # ============================== @@ -107,7 +183,7 @@ def main(): '--plugin', type=str, default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], help="plugin to use") parser.add_argument( "--model_type", @@ -116,6 +192,7 @@ def main(): help="bert or albert", ) parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() if args.model_type == 'bert': @@ -145,6 +222,17 @@ def main(): plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) booster = Booster(plugin=plugin, **booster_kwargs) @@ -165,8 +253,9 @@ def main(): # bert pretrained model cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg) + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() elif model_name == "albert-xxlarge-v2": model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) else: @@ -196,19 +285,27 @@ def main(): num_training_steps=total_steps, ) + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) # ============================== # Train model # ============================== for epoch in range(NUM_EPOCHS): - train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, + data_builder.eval_splits, booster, coordinator) if coordinator.is_master(): print(results) diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh index 7fc6daabb..394ff831b 100755 --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -3,6 +3,6 @@ set -xe pip install -r requirements.txt -for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" done From e79b1e80e25a14c345a2702995b38e418d26c12a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 4 Sep 2023 23:25:01 +0800 Subject: [PATCH 29/33] [checkpointio] support huggingface from_pretrained for all plugins (#4606) --- colossalai/booster/plugin/gemini_plugin.py | 2 + .../checkpoint_io/general_checkpoint_io.py | 2 + .../test_hybrid_huggingface_compatibility.py | 129 ------------------ .../test_plugins_huggingface_compatibility.py | 83 +++++++++++ 4 files changed, 87 insertions(+), 129 deletions(-) delete mode 100644 tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py create mode 100644 tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 0f5ba6e9a..8489a8f29 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -18,6 +18,7 @@ from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, get_shard_filename, load_shard_state_dict, + save_config_file, save_state_dict, save_state_dict_shards, ) @@ -111,6 +112,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model.module, checkpoint_path) logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 83e4bdcc8..09362d145 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -23,6 +23,7 @@ from .utils import ( load_state_dict, load_state_dict_into_model, load_states_into_optimizer, + save_config_file, save_param_groups, save_state_dict, save_state_dict_shards, @@ -185,6 +186,7 @@ class GeneralCheckpointIO(CheckpointIO): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint_path, is_master=True) logging.info(f"The model is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py deleted file mode 100644 index df907605d..000000000 --- a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from torch.optim import Adam -from utils import shared_tempdir - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import ( - check_state_dict_equal, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) -from tests.kit.model_zoo import model_zoo - - -def exam_from_pretrained(model_fn, - data_gen_fn, - output_transform_fn, - loss_fn, - test_config, - shard=True, - size_per_shard=32): - - def _criterion(outputs, inputs): - outputs = output_transform_fn(outputs) - loss = criterion(outputs) - return loss - - def _preprocess_data(data): - if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - return iter([data]) - else: - return {k: v.cuda() for k, v in data.items()} - - model = model_fn() - optimizer = Adam((model.parameters()), lr=0.001) - criterion = loss_fn - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - data = data_gen_fn() - model.train() - if booster.plugin.stage_manager is not None: - booster.execute_pipeline(_preprocess_data(data), - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) - else: - output = model(**_preprocess_data(data)) - loss = criterion(output) - optimizer.backward(loss) - - optimizer.step() - - with shared_tempdir() as tempdir: - - model_ckpt_path = f"{tempdir}/model" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - dist.barrier() - - new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path) - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - - check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) - - Randomizer.reset_index() - torch.cuda.empty_cache() - - -@clear_cache_before_run() -@parameterize('test_config', [{ - 'tp_size': 4, - 'pp_size': 1, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 2, - 'pp_size': 1, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) -def run_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - clear_layout_converter() - torch.cuda.empty_cache() - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_test() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_huggingface_compatibility(world_size): - spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py new file mode 100644 index 000000000..3f3b0392a --- /dev/null +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -0,0 +1,83 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('model_name', ['transformers_gpt']) +@parameterize('plugin_type', ['ddp', 'zero', 'gemini']) +def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + + if plugin_type == 'ddp': + plugin = TorchDDPPlugin() + elif plugin_type == 'zero': + plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) + elif plugin_type == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32) + else: + raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") + + booster = Booster(plugin=plugin) + + model = model_fn().cuda() + model_huggingface_cls = model.__class__ + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_huggingface_cls.from_pretrained(model_ckpt_path) + new_model = new_model.cuda() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + if plugin_type == 'gemini': + check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), + new_model.unwrap().state_dict(only_rank_0=False), False) + else: + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + dist.barrier() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_from_pretrained() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) From 86d22581e42b350fbe9c5a1f7bc45f7487620214 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:52:23 +0800 Subject: [PATCH 30/33] [shardformer] Add overlap optional for HybridParallelPlugin (#4615) * add optional overlap for plugin * remove fixed todo --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 +++- colossalai/shardformer/layer/_operation.py | 2 -- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8ad9b7956..d33e3485c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, initial_scale: float = 2**16, @@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase): enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism) + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f45ccc64b..45b305733 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): overlap = ctx.overlap if not overlap: - # TODO: overlap SP input with gradient computation input_parallel = _gather(input_, dim, process_group) total_input = input_parallel @@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): grad_output = grad_output.view(-1, grad_output.shape[-1]) total_input = total_input.view(-1, total_input.shape[-1]) - # TODO: overlap SP input with gradient computation if ctx.async_grad_reduce_scatter: # Asynchronous reduce-scatter input_list = [ From ec0866804c3f028f73d4e4d0bc1f3309362c4e89 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 5 Sep 2023 13:14:41 +0800 Subject: [PATCH 31/33] [shardformer] update shardformer readme (#4617) [shardformer] update shardformer readme [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 11 ++++++----- examples/language/bert/README.md | 14 ++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 7dc15f0a0..2e48a79dc 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. -| accuracy | f1 | loss | GPU number | model shard | + +| accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82594 | 0.87441 | 0.09913 | 4 | True | -| 0.81884 | 0.87299 | 0.10120 | 2 | True | -| 0.81855 | 0.87124 | 0.10357 | 1 | False | +| 0.84589 | 0.88613 | 0.43414 | 4 | True | +| 0.83594 | 0.88064 | 0.43298 | 1 | False | + Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md index da38e8375..6601edb79 100644 --- a/examples/language/bert/README.md +++ b/examples/language/bert/README.md @@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be bash test_ci.sh ``` -### Results on 2-GPU +### Bert-Finetune Results + +| Plugin | Accuracy | F1-score | GPU number | +| -------------- | -------- | -------- | -------- | +| torch_ddp | 84.4% | 88.6% | 2 | +| torch_ddp_fp16 | 84.7% | 88.8% | 2 | +| gemini | 84.0% | 88.4% | 2 | +| hybrid_parallel | 84.5% | 88.6% | 4 | -| Plugin | Accuracy | F1-score | -| -------------- | -------- | -------- | -| torch_ddp | 84.4% | 88.6% | -| torch_ddp_fp16 | 84.7% | 88.8% | -| gemini | 84.0% | 88.4% | ## Benchmark ``` From e71d2452936372eaca5d300d43a11c35958fc011 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Sep 2023 14:21:31 +0800 Subject: [PATCH 32/33] [test] ignore gpt2 shardformer test (#4619) --- tests/test_shardformer/test_model/test_shard_gpt2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index a4def9e50..24f5137ae 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -102,6 +102,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() +@pytest.mark.skip(reason="This test will hang in CI") @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, From bd18678478e5ecd18a9fa8a70eedea6f1fcdd036 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Sep 2023 16:02:23 +0800 Subject: [PATCH 33/33] [test] fix gemini checkpoint and gpt test (#4620) --- .../test_plugins_huggingface_compatibility.py | 2 +- tests/test_shardformer/test_model/test_shard_gpt2.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index 3f3b0392a..bd041a5e2 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per elif plugin_type == 'zero': plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) elif plugin_type == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32) + plugin = GeminiPlugin(precision="fp16", initial_scale=32) else: raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 24f5137ae..768063e53 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -102,7 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@pytest.mark.skip(reason="This test will hang in CI") @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, @@ -220,7 +219,7 @@ def check_gpt2_3d(rank, world_size, port): run_gpt2_3d_test() - +@pytest.mark.skip(reason="This test will hang in CI") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()