[shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code
This commit is contained in:
FoolPlayer
2023-07-14 15:56:59 +08:00
committed by Hongxin Liu
parent c59d7aca09
commit dd2bf02679
10 changed files with 733 additions and 10 deletions

View File

@@ -25,6 +25,7 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@@ -33,7 +34,7 @@ from ._operation import (
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']
# ====================================
# For GPT Only
@@ -490,3 +491,175 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
return output
else:
return output, self.bias
# ====================================
# For Fused torch.nn.Linear
# ====================================
class FusedLinear1D_Col(ParallelModule):
r"""Fused Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Parameters.
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
self.weight = customized_distributed_tensor_to_param(sharded_weight)
if bias:
bias = torch.empty(self.out_features, **factory_kwargs)
with torch.no_grad():
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
self.bias = customized_distributed_tensor_to_param(sharded_bias)
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
"""
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = FusedLinear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=False)
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=False)
linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output