Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 20:10:17 +00:00
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
@@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import (
 )
 
 from ._operation import (
-    gather_forward_reducescatter_backward,
     gather_forward_split_backward,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
-    reducescatter_forward_gather_backward,
     split_forward_gather_backward,
 )
 from .parallel_module import PaddingParallelModule, ParallelModule
-from .utils import create_randomizer_with_offset
+from .utils import create_randomizer_with_offset, is_share_sp_tp
 
 __all__ = ["Linear1D_Col", "Linear1D_Row"]
 
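A note on the new import: `is_share_sp_tp` (from `.utils`) is what lets the separate "split_gather" and "ring" branches in the forward passes below collapse into a single code path. Its implementation is not part of this diff; a minimal sketch of the assumed behavior, inferred only from how the forward code uses it:

    # Sketch only: assumed behavior of is_share_sp_tp, not the library source.
    # "split_gather" and "ring" are the sequence-parallel modes that reuse the
    # tensor-parallel process group, so both can share one fused linear path.
    def is_share_sp_tp(seq_parallel_mode) -> bool:
        return seq_parallel_mode in ("split_gather", "ring")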
@@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule):
                     to all GPUs, otherwise, every GPU will have its output
                     which is :math:`Y_i = XA_i`, defaults to False
         seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
-        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):
@@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule):
         gather_output: bool = False,
         seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
-        overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
@@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule):
         self.gather_output = gather_output
         self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
-        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
@@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
 
-        if self.seq_parallel_mode == "split_gather":
-            input_parallel = gather_forward_reducescatter_backward(
-                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
-            )
-            output_parallel = linear_with_async_comm(
-                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
-            )
-        elif self.seq_parallel_mode == "ring":
+        if is_share_sp_tp(self.seq_parallel_mode):
             output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                True,
+                self.seq_parallel_dim,
+                ring=self.seq_parallel_mode == "ring",
             )
         else:
             output_parallel = linear_with_async_comm(
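For context on what the fused call computes: under sequence parallelism the column-parallel linear first needs the full sequence, so the forward all-gathers the sequence shards and multiplies by the local weight shard, giving Y_i = X A_i on every rank; as the function name says, the matching backward reduce-scatters the gradient back to sequence shards, which is why one fused op can replace the old explicit gather followed by linear_with_async_comm. A single-process sanity sketch of that math, simulating the all-gather with torch.cat (plain PyTorch only; the names below are illustrative, not the ColossalAI API):

    # Single-process sketch of the column-parallel + sequence-parallel math.
    # Communication is simulated with chunk/cat; names here are illustrative only.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    tp_size, batch, seq, hidden, out_features = 2, 2, 8, 16, 32
    x = torch.randn(batch, seq, hidden)         # full activation [b, s, h]
    weight = torch.randn(out_features, hidden)  # full column-parallel weight

    x_shards = x.chunk(tp_size, dim=1)          # each rank holds one sequence shard
    w_shards = weight.chunk(tp_size, dim=0)     # each rank holds one A_i (output-dim shard)

    # Forward on rank i: all-gather the sequence shards, then Y_i = X A_i.
    x_gathered = torch.cat(x_shards, dim=1)     # simulated all-gather along the seq dim
    rank_outputs = [F.linear(x_gathered, w_i) for w_i in w_shards]

    # Concatenating the per-rank Y_i along the feature dim recovers the dense result.
    assert torch.allclose(torch.cat(rank_outputs, dim=-1), F.linear(x, weight), atol=1e-5)
    print("column-parallel forward with gathered sequence matches the dense linear")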
@@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule):
                 handle.wait()
             output = torch.cat(output_parallel_list, dim=-1)
         else:
-            if self.seq_parallel_mode == "split_gather":
-                output_parallel = F.linear(input_, self.weight)
-                output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
-                )
-            elif self.seq_parallel_mode == "ring":
+            if is_share_sp_tp(self.seq_parallel_mode):
                 output = linear_reducescatter_forward_gather_backward(
                     input_,
                     self.weight,
                     process_group=self.process_group,
                     dim=self.seq_parallel_dim,
-                    ring=True,
+                    ring=self.seq_parallel_mode == "ring",
                 )
             else:
                 output_parallel = F.linear(input_, self.weight)
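The row-parallel path is the mirror image: each rank multiplies its hidden-dim shard of the input by its weight shard, producing a partial sum of the full output, and with sequence parallelism those partial sums are combined by a reduce-scatter along the sequence dimension so every rank keeps only its own sequence shard. The old code expressed this as F.linear followed by reducescatter_forward_gather_backward; the fused linear_reducescatter_forward_gather_backward now handles both shared-SP modes, with the ring flag selecting ring-style communication. A single-process sketch of that math, again simulating the collectives with plain tensor ops (illustrative names only, not the ColossalAI API):

    # Single-process sketch of the row-parallel + reduce-scatter math.
    # Communication is simulated with sum/chunk; names here are illustrative only.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    tp_size, batch, seq, hidden, out_features = 2, 2, 8, 32, 16
    x = torch.randn(batch, seq, hidden)         # full activation [b, s, h]
    weight = torch.randn(out_features, hidden)  # full row-parallel weight

    x_shards = x.chunk(tp_size, dim=-1)         # input split along the hidden dim
    w_shards = weight.chunk(tp_size, dim=1)     # weight split along the input dim

    # Each rank produces a partial sum of the full output ...
    partials = [F.linear(x_i, w_i) for x_i, w_i in zip(x_shards, w_shards)]
    full = torch.stack(partials).sum(dim=0)     # what an all-reduce would produce

    # ... and a reduce-scatter along the sequence dim leaves rank r with shard r only.
    seq_shards = full.chunk(tp_size, dim=1)
    assert torch.allclose(full, F.linear(x, weight), atol=1e-4)
    assert all(s.shape == (batch, seq // tp_size, out_features) for s in seq_shards)
    print("row-parallel partial sums reduce to the dense linear; outputs stay sequence-sharded")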
@@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
                     to all GPUs, otherwise, every GPU will have its output
                     which is :math:`Y_i = XA_i`, defaults to False
         seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
-        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):