Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
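In short: shardformer layers can now shard an existing module's parameters in place instead of allocating fresh sharded copies, so parameters shared across modules (e.g. tied embeddings) stay shared after sharding. A minimal usage sketch, assuming torch.distributed is already initialized and `pg` is a 2-rank tensor-parallel ProcessGroup (the launcher and the `pg` handle are placeholders, not part of this commit):

    # Sketch only: assumes an initialized 2-rank ProcessGroup `pg`.
    import torch.nn as nn

    from colossalai.shardformer.layer import Linear1D_Col

    full = nn.Linear(16, 32)
    original_weight = full.weight  # keep a handle to the original Parameter

    col = Linear1D_Col.from_native_module(full, process_group=pg)

    # With in-place sharding, the original Parameter object is reused and its
    # .data is swapped for the local shard, so tied weights remain tied.
    assert col.weight is original_weight
    assert col.weight.shape == (32 // pg.size(), 16)  # row shard of (out, in)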
@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 import math
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -15,7 +15,12 @@ from torch.nn.parameter import Parameter
 from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
+from colossalai.tensor.d_tensor.api import (
+    is_distributed_tensor,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_existing_param,
+)
 
 from ._operation import (
     gather_forward_split_backward,
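The import swap above is the crux of the change: sharded_tensor_to_param wrapped a shard in a brand-new Parameter, while sharded_tensor_to_existing_param writes the shard into a Parameter that already exists. The call pattern as it appears in the hunks below, shown as a sketch (not runnable stand-alone; `process_group` must come from an initialized distributed setup):

    import torch
    from torch.nn import Parameter

    from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param

    param = Parameter(torch.empty(32, 16))              # pre-existing Parameter
    sharded = shard_rowwise(param.data, process_group)  # local shard of dim 0
    sharded_tensor_to_existing_param(sharded, param)    # re-point param.data in place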
@@ -65,6 +70,8 @@ class Linear1D_Col(ParallelModule):
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
@@ -80,26 +87,42 @@
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
-        # Parameters.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
-        sharded_weight = shard_rowwise(weight, self.process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
-        if bias:
-            bias = torch.empty(self.out_features, **factory_kwargs)
-            sharded_bias = shard_colwise(bias, self.process_group)
-            self.bias = sharded_tensor_to_param(sharded_bias)
-        else:
-            self.bias = None
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        # init weights
-        self.reset_parameters(weight_initializer, bias_initializer)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
+        # Parameters.
+        if weight is None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_rowwise(self.weight.data, self.process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+            if not is_distributed_tensor(self.bias):
+                sharded_bias = shard_colwise(self.bias.data, self.process_group)
+                sharded_tensor_to_existing_param(sharded_bias, self.bias)
+        else:
+            self.bias = None
+
+        if weight is None:
+            # init weights
+            self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
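For orientation, shard_rowwise on the stored (out_features, in_features) weight amounts to chunking along dim 0 and keeping the local slice, plus distributed-layout bookkeeping. A hypothetical plain-PyTorch toy of the per-rank arithmetic (toy_shard_rowwise is illustrative, not the library function):

    import torch

    def toy_shard_rowwise(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Split dim 0 evenly across ranks and keep the local slice.
        assert t.shape[0] % world_size == 0, 'rows must divide evenly across ranks'
        return t.chunk(world_size, dim=0)[rank].contiguous()

    w = torch.arange(32, dtype=torch.float32).reshape(8, 4)  # (out=8, in=4)
    shards = [toy_shard_rowwise(w, r, 2) for r in range(2)]
    assert shards[0].shape == (4, 4)
    assert torch.equal(torch.cat(shards, dim=0), w)  # shards reassemble the weight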
@@ -125,17 +148,11 @@ class Linear1D_Col(ParallelModule):
                                  bias=bias,
                                  device=device,
                                  process_group=process_group,
+                                 weight=module.weight,
+                                 bias_=module.bias,
                                  *args,
                                  **kwargs)
 
-        with torch.no_grad():
-            # the weight to the linear layer is a transpose
-            # thus shard on row is equal to shard on column
-            sharded_weight = shard_rowwise(module.weight.data, process_group)
-            linear_1d.weight.data.copy_(sharded_weight)
-            if bias:
-                sharded_bias = shard_colwise(module.bias.data, process_group)
-                linear_1d.bias.copy_(sharded_bias)
         return linear_1d
 
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
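The deleted comment ("the weight to the linear layer is a transpose, thus shard on row is equal to shard on column") still explains the new path: nn.Linear stores its weight as (out_features, in_features), so splitting the layer's output features means splitting the stored matrix's rows. A self-contained toy check, independent of this commit:

    import torch

    x = torch.randn(3, 4)  # batch of 3, in_features=4
    w = torch.randn(8, 4)  # nn.Linear weight layout: (out=8, in=4)
    full = x @ w.t()       # (3, 8)

    # Column parallelism: each "rank" owns half of the output features,
    # i.e. half of the weight's rows.
    w0, w1 = w.chunk(2, dim=0)
    gathered = torch.cat([x @ w0.t(), x @ w1.t()], dim=-1)
    assert torch.allclose(full, gathered)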
@@ -198,6 +215,8 @@ class Linear1D_Row(ParallelModule):
                  process_group: ProcessGroup = None,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
+                 weight: Optional[Parameter] = None,
+                 bias_: Optional[Parameter] = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  stream_chunk_num: int = 1):
@@ -216,27 +235,44 @@ class Linear1D_Row(ParallelModule):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
 
-        # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
-        sharded_weight = shard_colwise(weight, self.process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
-        if self.stream_chunk_num > 1:
-            # TODO() work for inference only
-            self.chunk_weight()
-        if bias:
-            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
-        else:
-            self.bias = None
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer, bias_initializer)
+        # sanity check
+        if weight is not None:
+            assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+        else:
+            assert bias_ is None, 'bias_ must be None if weight is None'
+
+        # Parameters.
+        if weight is None:
+            # Initialize weight.
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_colwise(self.weight.data, self.process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if self.stream_chunk_num > 1:
+            # TODO() work for inference only
+            self.chunk_weight()
+
+        if bias:
+            if bias_ is None:
+                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+                self.bias = bias_
+        else:
+            self.bias = None
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
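Note that Linear1D_Row leaves the bias whole: each rank computes a partial product over its slice of in_features, the partials are summed across ranks (an all-reduce in the real layer), and the full bias is added once, which is why the hunk above neither shards bias_ nor checks it with is_distributed_tensor. A toy check of that decomposition:

    import torch

    x = torch.randn(3, 8)  # in_features=8
    w = torch.randn(4, 8)  # (out=4, in=8)
    b = torch.randn(4)
    full = x @ w.t() + b

    # Row parallelism: split in_features across 2 "ranks"; shard_colwise on the
    # stored weight matches the split of x along its last dimension.
    x0, x1 = x.chunk(2, dim=-1)
    w0, w1 = w.chunk(2, dim=-1)
    reduced = (x0 @ w0.t()) + (x1 @ w1.t())  # stands in for the all-reduce
    assert torch.allclose(full, reduced + b, atol=1e-5)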
@@ -262,19 +298,11 @@ class Linear1D_Row(ParallelModule):
                                  bias=bias,
                                  device=device,
                                  process_group=process_group,
+                                 weight=module.weight,
+                                 bias_=module.bias,
                                  *args,
                                  **kwargs)
 
-        # TODO: copy the sharded weights
-        with torch.no_grad():
-            # the weigh to the linear layer is a transpose
-            # thus shard on col is equal to shard on row
-            sharded_weight = shard_colwise(module.weight.data, process_group)
-            linear_1d.weight.data.copy_(sharded_weight)
-
-            if bias:
-                linear_1d.bias.copy_(module.bias.data)
-
         return linear_1d
 
     def chunk_weight(self):