[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device
from ._operation import (
@@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_rowwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
bias = torch.empty(self.out_features, **factory_kwargs)
sharded_bias = shard_colwise(bias, self.process_group)
self.bias = sharded_tensor_to_param(sharded_bias)
else:
self.bias = None
@@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weigh to the linear layer is a transpose
# thus shard on row is equal to shard on column
@@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_colwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
@@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias
return output, self.bias
return output, self.bias