[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
author Hongxin Liu
date 2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions
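The commit message above boils down to one mechanical change: instead of allocating a fresh parameter, sharding it, and copying data into it, each shardformer layer now shards the parameter object it is handed, in place, so existing references to that nn.Parameter stay valid. The sketch below is plain PyTorch rather than the ColossalAI d_tensor API, and tp_rank/tp_size are hypothetical stand-ins for the process group; it only illustrates why keeping the same object matters.

import torch.nn as nn


def shard_rowwise_inplace(param: nn.Parameter, tp_rank: int, tp_size: int) -> nn.Parameter:
    # Keep the SAME Parameter object; only its .data is swapped for the local row shard.
    local_rows = param.shape[0] // tp_size
    param.data = param.data[tp_rank * local_rows:(tp_rank + 1) * local_rows].clone()
    return param


linear = nn.Linear(16, 32)
ref = linear.weight                       # e.g. a tied / shared reference held elsewhere
shard_rowwise_inplace(linear.weight, tp_rank=0, tp_size=2)

assert linear.weight is ref               # identity is preserved, so ties stay intact
assert linear.weight.shape == (16, 16)    # only out_features // tp_size rows remain locally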


@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 import math
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -15,7 +15,12 @@ from torch.nn.parameter import Parameter
 from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
+from colossalai.tensor.d_tensor.api import (
+    is_distributed_tensor,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_existing_param,
+)
 
 from ._operation import (
     gather_forward_split_backward,
@@ -65,6 +70,8 @@ class Linear1D_Col(ParallelModule):
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
@@ -80,26 +87,42 @@ class Linear1D_Col(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
-        # Parameters.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
-        sharded_weight = shard_rowwise(weight, self.process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
-        if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-            sharded_bias = shard_colwise(bias, self.process_group)
-            self.bias = sharded_tensor_to_param(sharded_bias)
-        else:
-            self.bias = None
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        # init weights
-        self.reset_parameters(weight_initializer, bias_initializer)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
+        # Parameters.
+        if weight is None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_rowwise(self.weight.data, self.process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_distributed_tensor(self.bias):
+                sharded_bias = shard_colwise(self.bias.data, self.process_group)
+                sharded_tensor_to_existing_param(sharded_bias, self.bias)
+        else:
+            self.bias = None
+
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
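A hedged usage sketch for the new constructor path above: it assumes torch.distributed is already initialized (for example, two tensor-parallel ranks) and that Linear1D_Col is importable from colossalai.shardformer.layer. Note how the sanity check constrains the new arguments: when a weight is passed and bias=True, bias_ must be passed as well.

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

linear = nn.Linear(16, 32)
col_linear = Linear1D_Col(in_features=16,
                          out_features=32,
                          bias=True,
                          process_group=None,      # default group, assumed to be initialized
                          weight=linear.weight,    # reuse the existing Parameter objects
                          bias_=linear.bias)       # required here because bias=True and weight is given

# The weight and bias were sharded in place; no new Parameter objects were created.
assert col_linear.weight is linear.weight
assert col_linear.bias is linear.bias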
@@ -125,17 +148,11 @@ class Linear1D_Col(ParallelModule):
                                  bias=bias,
                                  device=device,
                                  process_group=process_group,
+                                 weight=module.weight,
+                                 bias_=module.bias,
                                  *args,
                                  **kwargs)
 
-        with torch.no_grad():
-            # the weight to the linear layer is a transpose
-            # thus shard on row is equal to shard on column
-            sharded_weight = shard_rowwise(module.weight.data, process_group)
-            linear_1d.weight.data.copy_(sharded_weight)
-            if bias:
-                sharded_bias = shard_colwise(module.bias.data, process_group)
-                linear_1d.bias.copy_(sharded_bias)
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
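Because from_native_module now forwards weight=module.weight and bias_=module.bias, the wrapper shares the original module's parameters instead of holding freshly copied shards. Below is a hedged sketch of why that matters for the "fix shared param sharding" bullet above, under the same assumptions about the distributed setup and import path; in a real policy the tied embedding would itself be wrapped by a parallel layer so that both views stay consistent.

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

embed = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed.weight          # weight tying, as in many language models

sharded_head = Linear1D_Col.from_native_module(lm_head, process_group=None)

# The head's weight was sharded in place, so the tie to the embedding survives.
assert sharded_head.weight is embed.weight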
@@ -198,6 +215,8 @@ class Linear1D_Row(ParallelModule):
                  process_group: ProcessGroup = None,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  stream_chunk_num: int = 1):
@@ -216,27 +235,44 @@ class Linear1D_Row(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
-        # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
-        sharded_weight = shard_colwise(weight, self.process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
-        if self.stream_chunk_num > 1:
-            # TODO() work for inference only
-            self.chunk_weight()
-        if bias:
-            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
-        else:
-            self.bias = None
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer, bias_initializer)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
+        # Parameters.
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_colwise(self.weight.data, self.process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if self.stream_chunk_num > 1:
+            # TODO() work for inference only
+            self.chunk_weight()
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -262,19 +298,11 @@ class Linear1D_Row(ParallelModule):
                                  bias=bias,
                                  device=device,
                                  process_group=process_group,
+                                 weight=module.weight,
+                                 bias_=module.bias,
                                  *args,
                                  **kwargs)
 
-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            # the weigh to the linear layer is a transpose
-            # thus shard on col is equal to shard on row
-            sharded_weight = shard_colwise(module.weight.data, process_group)
-            linear_1d.weight.data.copy_(sharded_weight)
-            if bias:
-                linear_1d.bias.copy_(module.bias.data)
         return linear_1d
 
     def chunk_weight(self):
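The row-parallel path mirrors this: the weight is sharded column-wise (along in_features) in place, and the bias is left unsharded. A short hedged sketch under the same assumptions as the column-parallel examples:

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Row

linear = nn.Linear(16, 32)
row_linear = Linear1D_Row.from_native_module(linear, process_group=None)

# Same in-place guarantee as the column-parallel case: the original Parameter
# objects are reused, with the weight data replaced by its column-wise local shard.
assert row_linear.weight is linear.weight
assert row_linear.bias is linear.bias      # the bias is not sharded for the row-parallel layer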