[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
Author: Hongxin Liu
Date: 2024-10-11 13:44:40 +08:00
Committer: GitHub
Parent: 6b2c506fc5
Commit: dc2cdaf3e8
13 changed files with 111 additions and 278 deletions
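
In short, the diff below removes the separate "split_gather" and "ring" code paths in Linear1D_Col and Linear1D_Row: both shared sequence-parallel modes now go through the fused linear ops, gated by the is_share_sp_tp helper, with a ring flag choosing the communication schedule, and the now-unused overlap argument is dropped. A minimal, self-contained sketch of that dispatch pattern follows (stub functions stand in for the real ops in colossalai.shardformer.layer._operation; fused_gather_linear and col_forward are illustrative names, not library APIs):

# Hedged sketch of the unified dispatch this commit introduces (simplified,
# not ColossalAI's actual implementation).
import torch
import torch.nn.functional as F

def is_share_sp_tp(sp_mode):
    # Stand-in for the helper imported from .utils in the diff; assumed to be
    # true for the sequence-parallel modes that reuse the TP process group.
    return sp_mode in ("split_gather", "ring")

def fused_gather_linear(x, weight, bias, ring):
    # Stand-in for linear_gather_forward_reducescatter_backward: the real op
    # fuses the input all-gather (ring-scheduled or not) with the matmul and
    # reduce-scatters gradients in backward. Here it is just a plain matmul.
    return F.linear(x, weight, bias)

def col_forward(x, weight, bias, sp_mode):
    if is_share_sp_tp(sp_mode):
        # One branch replaces the old separate "split_gather" / "ring" branches.
        return fused_gather_linear(x, weight, bias, ring=(sp_mode == "ring"))
    # No shared sequence parallelism: plain column-parallel matmul.
    return F.linear(x, weight, bias)

if __name__ == "__main__":
    x, w, b = torch.randn(2, 4, 8), torch.randn(16, 8), torch.randn(16)
    for mode in ("split_gather", "ring", None):
        print(mode, col_forward(x, w, b, mode).shape)  # torch.Size([2, 4, 16])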


@@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import (
 )
 from ._operation import (
-    gather_forward_reducescatter_backward,
     gather_forward_split_backward,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
-    reducescatter_forward_gather_backward,
     split_forward_gather_backward,
 )
 from .parallel_module import PaddingParallelModule, ParallelModule
-from .utils import create_randomizer_with_offset
+from .utils import create_randomizer_with_offset, is_share_sp_tp
 
 __all__ = ["Linear1D_Col", "Linear1D_Row"]
@@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule):
                     to all GPUs, otherwise, every GPU will have its output
                     which is :math:`Y_i = XA_i`, defaults to False
         seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
-        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):
@@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule):
         gather_output: bool = False,
         seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
-        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
@@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule):
         self.gather_output = gather_output
         self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
@@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel_mode == "split_gather":
-            input_parallel = gather_forward_reducescatter_backward(
-                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
-            )
-            output_parallel = linear_with_async_comm(
-                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
-            )
-        elif self.seq_parallel_mode == "ring":
+        if is_share_sp_tp(self.seq_parallel_mode):
             output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                True,
+                self.seq_parallel_dim,
+                ring=self.seq_parallel_mode == "ring",
             )
         else:
             output_parallel = linear_with_async_comm(
@@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule):
                     handle.wait()
                 output = torch.cat(output_parallel_list, dim=-1)
         else:
-            if self.seq_parallel_mode == "split_gather":
-                output_parallel = F.linear(input_, self.weight)
-                output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
-                )
-            elif self.seq_parallel_mode == "ring":
+            if is_share_sp_tp(self.seq_parallel_mode):
                 output = linear_reducescatter_forward_gather_backward(
                     input_,
                     self.weight,
                     process_group=self.process_group,
                     dim=self.seq_parallel_dim,
-                    ring=True,
+                    ring=self.seq_parallel_mode == "ring",
                 )
             else:
                 output_parallel = F.linear(input_, self.weight)
@@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
                     to all GPUs, otherwise, every GPU will have its output
                     which is :math:`Y_i = XA_i`, defaults to False
         seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
-        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):