Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)
* [sequence parallel] add sequence parallel linear col/row support (#4336)
* add sequence parallel linear col/row support
* add annotation
* add annotation
* add support for gpt2 fused qkv linear layer
* support sequence parallel in GPT2
* add docstring and note
* add requirements
* remove unused flash-attn
* modify flash attn test
* modify flash attn setting
* modify flash attn code
* add assert before divide, rename forward function
* [shardformer/test] fix gpt2 test with seq-parallel
* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)
* overlap gather input / grad computing during col backward
* modify test for overlap
* simplify code
* fix code and modify cuda stream synchronize
* [shardformer/sequence parallel] polish code
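Note on the column-parallel path: with sequence parallelism the activation arrives sharded along the sequence dimension, so the column linear all-gathers it before the matmul and reduce-scatters the incoming gradient on the way back, which is what `linear_gather_forward_reducescatter_backward` appears to wrap. A minimal sketch of that communication pattern, assuming an already-initialized torch.distributed process group and a sequence length divisible by the world size; `_GatherForwardReduceScatterBackward` and `seq_parallel_column_linear` are illustrative names, not the ColossalAI API:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
    """All-gather the sequence-sharded input in forward; reduce-scatter its gradient in backward."""

    @staticmethod
    def forward(ctx, x, process_group, dim):
        ctx.process_group = process_group
        ctx.dim = dim
        world_size = dist.get_world_size(process_group)
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        chunks = [c.contiguous() for c in grad_output.chunk(world_size, dim=ctx.dim)]
        grad_input = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_input, chunks, group=ctx.process_group)  # sum, then keep this rank's shard
        return grad_input, None, None


def seq_parallel_column_linear(x_shard, weight, bias, process_group, seq_dim=1):
    # x_shard: (batch, seq_len / world_size, in_features); weight: this rank's column shard A_i.
    x_full = _GatherForwardReduceScatterBackward.apply(x_shard, process_group, seq_dim)
    return F.linear(x_full, weight, bias)  # Y_i = X A_i: full sequence, hidden-sharded output
```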
@@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_split_backward,
+    linear_gather_forward_reducescatter_backward,
+    linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
     split_forward_gather_backward,
@@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
         gather_output (bool, optional): If true, call all-gather on output and make Y available
             to all GPUs, otherwise, every GPU will have its output
             which is :math:`Y_i = XA_i`, defaults to False
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):
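On the `overlap` flag documented above: during the column backward, the gathered input is needed only for the weight gradient, so the re-gather can be launched asynchronously and hidden behind the input-gradient matmul. A rough sketch of that idea, assuming an initialized process group; the function and argument names below are illustrative, not the library's kernels:

```python
import torch
import torch.distributed as dist


def column_backward_with_overlap(grad_output, x_shard, weight, process_group, seq_dim=1):
    """grad_output: (..., seq, out_features); weight: (out_features, in_features)."""
    world_size = dist.get_world_size(process_group)

    # Launch the input all-gather without blocking.
    shards = [torch.empty_like(x_shard) for _ in range(world_size)]
    handle = dist.all_gather(shards, x_shard.contiguous(), group=process_group, async_op=True)

    # Overlap: this matmul needs only grad_output and the weight, not the gathered input.
    grad_gathered_input = grad_output @ weight  # (..., seq, in_features)

    # Wait for the gather, then form the weight gradient against the full sequence.
    handle.wait()
    x_full = torch.cat(shards, dim=seq_dim)
    grad_weight = grad_output.flatten(0, -2).t() @ x_full.flatten(0, -2)

    # The real path would then reduce-scatter grad_gathered_input back to sequence shards.
    return grad_gathered_input, grad_weight
```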
@@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
+                 seq_parallel: bool = False,
+                 overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
                  bias_: Optional[Parameter] = None,
@@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
+        self.seq_parallel = seq_parallel
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
@@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        if self.seq_parallel:
+            output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+                                                                           self.process_group, True, 1, self.overlap)
+        else:
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

         if self.gather_output:
             # All-gather across the partitions.
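Hypothetical usage of the new flags (keyword names are taken from this diff; the in_features/out_features arguments and the import path are assumptions, and a distributed environment must already be initialized):

```python
import torch.distributed as dist
from colossalai.shardformer.layer import Linear1D_Col  # import path assumed

tp_group = dist.new_group()  # tensor-parallel group; assumes dist.init_process_group was called
col = Linear1D_Col(
    in_features=1024,
    out_features=4096,
    process_group=tp_group,
    seq_parallel=True,    # input is expected to be sharded along the sequence dim
    overlap=True,         # hide the backward re-gather behind gradient computation
    gather_output=False,  # keep Y_i = X A_i sharded for a following Linear1D_Row
)
```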
@@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule):
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (:class:`typing.Callable`, optional):
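The split-input semantics documented here rest on a simple identity: if A is split by columns across ranks and B by rows, summing the per-rank products reproduces the full product, which is why the row linear only needs a reduce (or, with sequence parallelism, a reduce-scatter) of its local output. A quick single-process check, no distributed setup required:

```python
import torch

torch.manual_seed(0)
world_size = 2
X = torch.randn(4, 8)    # (tokens, hidden)
A = torch.randn(8, 16)   # weight of the column-parallel linear
B = torch.randn(16, 8)   # weight of the row-parallel linear

A_shards = A.chunk(world_size, dim=1)  # rank i holds the column block A_i
B_shards = B.chunk(world_size, dim=0)  # rank i holds the row block B_i

partials = [(X @ A_shards[i]) @ B_shards[i] for i in range(world_size)]
assert torch.allclose(sum(partials), X @ A @ B, atol=1e-5)
```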
@@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 seq_parallel: bool = False,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
+        self.seq_parallel = seq_parallel
         self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
@@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule):
             output = torch.cat(output_parallel_list, dim=-1)
         else:
             output_parallel = F.linear(input_, self.weight)
-            output = reduce_forward(output_parallel, self.process_group)
+            if self.seq_parallel:
+                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+            else:
+                output = reduce_forward(output_parallel, self.process_group)

         if not self.skip_bias_add:
             if self.bias is not None:
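For the branch above: on the row side the pattern is mirrored, so the local product is reduce-scattered along the sequence dimension in forward and the gradient is all-gathered in backward instead of the usual all-reduce, which is presumably what `linear_reducescatter_forward_gather_backward` wraps. A minimal sketch under that assumption (the class name is illustrative, not the library's):

```python
import torch
import torch.distributed as dist


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
    """Reduce-scatter along the sequence dim in forward; all-gather the gradient in backward."""

    @staticmethod
    def forward(ctx, x, process_group, dim):
        ctx.process_group = process_group
        ctx.dim = dim
        world_size = dist.get_world_size(process_group)
        chunks = [c.contiguous() for c in x.chunk(world_size, dim=dim)]
        out = torch.empty_like(chunks[0])
        dist.reduce_scatter(out, chunks, group=process_group)  # sum partial outputs, keep one sequence shard
        return out

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        shards = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(shards, grad_output.contiguous(), group=ctx.process_group)
        return torch.cat(shards, dim=ctx.dim), None, None
```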