[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)

* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize

* [shardformer/sequence parallel] polish code
Author: Bin Jia
Date: 2023-08-16 15:41:20 +08:00
Committed by: GitHub
Parent: d20dceb9a3
Commit: 424629fea0

12 changed files with 655 additions and 65 deletions
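Before reading the diff itself, it helps to spell out what the sequence-parallel path does at the collective level: `Linear1D_Col` with `seq_parallel=True` all-gathers the activation along the sequence dimension (dim 1) before its matmul and reduce-scatters the corresponding gradient in backward, while `Linear1D_Row` reduce-scatters its partial output along the sequence dimension in forward and all-gathers in backward. The snippet below is only a minimal, illustrative sketch of those two collectives in plain PyTorch; the real logic lives behind the `._operation` helpers imported in the diff and additionally handles autograd, fused QKV layers, and the async overlap path.

```python
# Illustrative sketch only -- not the ColossalAI implementation. Assumes the
# sequence length is divisible by the process-group size and tensors are laid
# out as [batch, seq, hidden] with the sequence dimension at dim 1.
import torch
import torch.distributed as dist


def gather_sequence(x: torch.Tensor, process_group, dim: int = 1) -> torch.Tensor:
    """All-gather a sequence-sharded activation [B, S/P, H] into [B, S, H]."""
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return x
    chunks = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(chunks, x.contiguous(), group=process_group)
    return torch.cat(chunks, dim=dim)


def reducescatter_sequence(x: torch.Tensor, process_group, dim: int = 1) -> torch.Tensor:
    """Sum partial results across ranks and shard them back to [B, S/P, H]."""
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return x
    chunks = [c.contiguous() for c in torch.chunk(x, world_size, dim=dim)]
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, group=process_group)
    return out
```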


@@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_split_backward,
+    linear_gather_forward_reducescatter_backward,
+    linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
     split_forward_gather_backward,
@@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
         gather_output (bool, optional): If true, call all-gather on output and make Y available
             to all GPUs, otherwise, every GPU will have its output
             which is :math:`Y_i = XA_i`, defaults to False
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):
@@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
+                 seq_parallel: bool = False,
+                 overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
                  bias_: Optional[Parameter] = None,
@@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
+        self.seq_parallel = seq_parallel
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
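With `seq_parallel` and `overlap` stored on the module, turning the feature on is purely a constructor-level switch. A hypothetical usage sketch (the layer sizes, the `tp_group` handle, and the import path are assumptions, not taken from this diff):

```python
# Hypothetical usage; sizes, tp_group and the import path are assumptions.
from colossalai.shardformer.layer import Linear1D_Col

col_linear = Linear1D_Col(
    in_features=1024,
    out_features=4096,
    process_group=tp_group,   # tensor-parallel group, assumed to exist already
    seq_parallel=True,        # input arrives sharded along the sequence dimension
    overlap=True,             # overlap the backward input all-gather with grad compute
)
```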
@@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        if self.seq_parallel:
+            output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+                                                                           self.process_group, True, 1, self.overlap)
+        else:
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
         if self.gather_output:
             # All-gather across the partitions.
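The `overlap` flag (from the second cherry-picked commit, #4401) concerns the backward pass: since only the sequence shard of the input is saved in forward, backward has to all-gather it again before computing the weight gradient, and that all-gather can be hidden behind the grad-input matmul by issuing it on a side CUDA stream. The sketch below shows the pattern only; it is not the actual `linear_gather_forward_reducescatter_backward` autograd function, and it omits the bias gradient and the reduce-scatter of `grad_input`.

```python
# Schematic of the overlap idea, under the assumptions stated above.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()   # side stream used only for communication


def backward_with_overlap(grad_output, local_input, weight, process_group):
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(local_input) for _ in range(world_size)]

    # Kick off the all-gather of the sequence-sharded input asynchronously ...
    with torch.cuda.stream(comm_stream):
        handle = dist.all_gather(gathered, local_input.contiguous(),
                                 group=process_group, async_op=True)

    # ... and overlap it with the grad-input matmul on the default stream.
    grad_input = grad_output.matmul(weight)   # [B, S, out/P] @ [out/P, H] -> [B, S, H]

    # Only the weight gradient needs the full-sequence input, so wait here.
    handle.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    total_input = torch.cat(gathered, dim=1)  # [B, S, H]

    grad_weight = grad_output.flatten(0, 1).t().matmul(total_input.flatten(0, 1))
    return grad_input, grad_weight
```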
@@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule):
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (:class:`typing.Callable`, optional):
@@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 seq_parallel: bool = False,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
+        self.seq_parallel = seq_parallel
         self.num_partitions = dist.get_world_size(self.process_group)
         if skip_bias_add and not bias:
@@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule):
             output = torch.cat(output_parallel_list, dim=-1)
         else:
             output_parallel = F.linear(input_, self.weight)
-            output = reduce_forward(output_parallel, self.process_group)
+            if self.seq_parallel:
+                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+            else:
+                output = reduce_forward(output_parallel, self.process_group)
         if not self.skip_bias_add:
             if self.bias is not None:
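Putting the two halves together: in the sequence-parallel path the column-parallel layer restores the full sequence right before its matmul, and the row-parallel layer reduce-scatters its partial sums back into a sequence shard, so the activation is only full-length inside the column-to-row pair (e.g. a GPT2 MLP or attention block). A rough sketch of that composition, reusing the illustrative `gather_sequence` / `reducescatter_sequence` helpers from the first sketch above:

```python
# Conceptual composition only; weights and shapes are placeholders.
import torch
import torch.nn.functional as F


def col_then_row_forward(x_shard, w_col, w_row, process_group):
    # x_shard: [B, S/P, H] sequence-sharded input; w_col: [4H/P, H]; w_row: [H, 4H/P]
    x_full = gather_sequence(x_shard, process_group)        # [B, S, H]
    h = F.linear(x_full, w_col)                             # [B, S, 4H/P], column-parallel
    h = torch.relu(h)                                       # any element-wise op
    partial = F.linear(h, w_row)                            # [B, S, H], partial sum per rank
    return reducescatter_sequence(partial, process_group)   # [B, S/P, H], summed and resharded
```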