[Sharderformer] Support zbv in Sharderformer Policy (#6150)

* [feat] Sharderformer support zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;

* [fix] fix incorrect number of gradients ;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Han doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] ifx torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;

* [fix] fix policy use fused_linear

* [fix] fix weight grad none, err caused by  weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shadformer] fix doc test cmd;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
duanjunwen
2025-01-02 10:22:26 +08:00
committed by GitHub
parent af06d162cf
commit a9bedc7a43
27 changed files with 3511 additions and 316 deletions

View File

@@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -28,8 +27,10 @@ from ._operation import (
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
matmul_with_grad_comm,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
@@ -37,7 +38,14 @@ from ._operation import (
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
__all__ = [
"FusedLinear1D_Col",
"FusedLinear1D_Row",
"FusedLinear",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv",
]
# ====================================
# For GPT Only
@@ -228,6 +236,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
@@ -241,6 +250,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.split_sizes = split_sizes
self.process_group = process_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == out_features
@@ -375,6 +385,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
1,
ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
@@ -386,6 +397,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@@ -441,6 +453,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
@@ -455,6 +468,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.seq_parallel_mode = seq_parallel_mode
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -620,6 +634,152 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
return output, self.bias
class GPT2FusedLinearConv(ParallelModule):
r"""Linear layer without parallelism.
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module,
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
Args:
module (`nn.Linear`): The module to be converted.
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
bias = module.bias is not None
device = module.weight.device
linear_1d = GPT2FusedLinearConv(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
input_parallel = input_
output_parallel = matmul_with_grad_comm(
input_parallel,
self.weight,
bias,
False,
self.use_zbv,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
# ====================================
# For Fused torch.nn.Linear
# ====================================
@@ -671,6 +831,7 @@ class FusedLinear1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
@@ -684,6 +845,7 @@ class FusedLinear1D_Col(ParallelModule):
self.split_sizes = split_sizes
self.process_group = process_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == out_features
@@ -811,10 +973,17 @@ class FusedLinear1D_Col(ParallelModule):
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
input_parallel,
self.weight,
bias,
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
if self.gather_output:
@@ -870,6 +1039,7 @@ class FusedLinear1D_Row(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
@@ -883,6 +1053,7 @@ class FusedLinear1D_Row(ParallelModule):
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == in_features
@@ -1009,9 +1180,18 @@ class FusedLinear1D_Row(ParallelModule):
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = F.linear(input_, self.weight) # Replace to LinearWithGradAccum
output_parallel = linear_with_grad_accum(
input_,
self.weight,
None,
False,
use_zbv=self.use_zbv,
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
@@ -1020,3 +1200,156 @@ class FusedLinear1D_Row(ParallelModule):
return output
else:
return output, self.bias
class FusedLinear(ParallelModule):
r"""Fused Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module,
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
linear_1d = FusedLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output