From 424629fea023a83aa84eacf55afc8007314d9f54 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:41:20 +0800 Subject: [PATCH] [shardformer/sequence parallel] Cherry pick commit to new branch (#4450) * [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish code --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/shardformer/layer/_operation.py | 276 +++++++++++++++++- colossalai/shardformer/layer/linear.py | 23 +- .../shardformer/layer/qkv_fused_linear.py | 27 +- colossalai/shardformer/modeling/gpt2_seq.py | 222 ++++++++++++++ .../shardformer/policies/base_policy.py | 26 +- colossalai/shardformer/policies/gpt2.py | 9 + colossalai/shardformer/shard/shard_config.py | 1 + .../test_gpt2_qkv_fused_linear_1d.py | 34 ++- .../test_layer/test_linear_1d.py | 75 +++-- tests/test_shardformer/test_model/_utils.py | 15 +- .../test_model/test_shard_gpt2.py | 7 + 12 files changed, 655 insertions(+), 65 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 28a19af0c..3d45a9112 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_fused_normalization: bool = False, enable_flash_attention: bool = False, enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, @@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase): enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused) + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01..13e563123 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch.distributed as dist import torch.nn.functional as F @@ -141,6 +143,215 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + # create new stream for calculate the gradient + calculate_stream = torch.cuda.Stream() + + # do all gather in default stream + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + + # calculate gradient in calculate_stream + with torch.cuda.stream(calculate_stream): + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + + torch.cuda.current_stream().wait_stream(calculate_stream) + + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + with torch.cuda.stream(calculate_stream): + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + print(grad_output.shape, input_parallel.shape) + grad_weight = grad_output.t().matmul(input_parallel) + + torch.cuda.current_stream().wait_stream(calculate_stream) + + return output, grad_weight, grad_bias, None, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + return output, grad_weight, grad_bias, None, None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +411,26 @@ class _ReduceBackward(torch.autograd.Function): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather + input_ = input_.contiguous() rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ @@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(input_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim, overlap) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d59b68ce4..69ac3ad25 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule): gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1, self.overlap) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule): output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index df942d43e..ccb2bf7ea 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -173,6 +176,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -185,6 +189,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -296,15 +301,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - # input_parallel = input_ # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -329,6 +338,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -346,6 +356,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -363,6 +374,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -499,7 +511,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py new file mode 100644 index 000000000..a6da96e7b --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2_seq.py @@ -0,0 +1,222 @@ +# this code is modified from transformers.models.gpt2.modeling_gpt2 +# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 + +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import logging + +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +# TODO: put all contents in `gpt2.py` and make it compatible with pipeline +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 69493bfb6..7022a1cfd 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,17 +11,12 @@ from torch.nn import Module from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] -class ParallelModule(): - - def __init__(self): - pass - - @dataclass class SubModuleReplacementDescription: r""" @@ -231,3 +226,22 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] + + def append_seq_parallel_to_policy( + self, + suffix_list: List[str], + module_policy_description: ModulePolicyDescription, + ): + r""" + Append the sequence parallel policy to the policy for the given key. + + Args: + suffix_list (List[str]): the suffix list of the module to be parallelized + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + """ + + for sub_description in module_policy_description.sub_module_replacement: + if (sub_description.suffix in suffix_list): + if sub_description.kwargs is None: + sub_description.kwargs = {} + sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 20e5fa372..276d95660 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -7,6 +7,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward +from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -49,6 +50,9 @@ class GPT2Policy(Policy): target_module=col_nn.DropoutForParallelInput, ), ]) + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -120,6 +124,11 @@ class GPT2Policy(Policy): policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ 'forward': get_gpt2_flash_attention_forward(), }) + + if self.shard_config.enable_sequence_parallelism: + suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] + self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) + return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0c28f115d..a36e878c6 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -28,6 +28,7 @@ class ShardConfig: enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False + enable_sequence_parallelism: bool = False # pipeline_parallel_size: int # data_parallel_size: int diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index b45cd172c..ae6a1dc90 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool): linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, + seq_parallel=seq_parallel, n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) @@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool): linear.load_state_dict(linear_conv_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) # check backward correctness out.sum().backward() @@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool): assert_close(target_grad, linear_row.weight.grad) +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_gpt2_qkv_fused_linear_1d() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index aa75879e0..3ad8f14b9 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) -def check_linear_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear_copy, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel, + overlap=overlap) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool): linear_row.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_col_plus_row(lazy_init: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool): with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1_copy, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel, + overlap=overlap) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool): linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -def run_dist(rank, world_size, port): +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +@parameterize('overlap', [False, True]) +def run_dist_linear_test(lazy_init, seq_parallel, overlap): + check_linear_1d_col(lazy_init, seq_parallel, overlap) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel, overlap) + + +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 921af2a8b..7e1e6f2fe 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,4 +1,5 @@ import copy +import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -25,6 +26,7 @@ def build_model(model_fn, enable_tensor_parallelism=True, enable_flash_attention=False, enable_jit_fused=False, + enable_sequence_parallelism=False, use_lazy_init: bool = False): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -38,7 +40,8 @@ def build_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) @@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo return loss data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data['input_ids'].shape[1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data['input_ids'].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(1, times) + sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index ca086bf12..c97702cbb 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': False, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', }]) @clear_cache_before_run() def run_gpt2_test(test_config):