Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading

* polish code
@@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
+from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from colossalai.utils.cuda import get_current_device

 from ._operation import (
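For context on the new `colossalai.tensor.d_tensor` entry point, the snippet below is a minimal sketch (not part of this commit) of how the re-exported helpers are used by the layers in the hunks that follow; the standalone function name and the `process_group=None` default are illustrative assumptions.

    import torch

    from colossalai.tensor.d_tensor import shard_rowwise, sharded_tensor_to_param

    def build_column_parallel_weight(out_features: int, in_features: int, process_group=None):
        # Materialize the full (out_features, in_features) weight once, then shard
        # it row-wise across the tensor-parallel process group, as Linear1D_Col
        # does in the hunk below.
        weight = torch.empty(out_features, in_features)
        sharded_weight = shard_rowwise(weight, process_group)
        # Wrap the sharded tensor so it behaves like a regular nn.Parameter and
        # takes part in state_dict saving and loading.
        return sharded_tensor_to_param(sharded_weight)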
@@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        self.out_features_per_partition = divide(out_features, self.num_partitions)

         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()
         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+            bias = torch.empty(self.out_features, **factory_kwargs)
+            sharded_bias = shard_colwise(bias, self.process_group)
+            self.bias = sharded_tensor_to_param(sharded_bias)
         else:
             self.bias = None
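The shape bookkeeping behind the change above, as a conceptual sketch: `torch.chunk` stands in for the real d_tensor sharding (an assumption for illustration only), but the per-rank shapes match what `shard_rowwise` on the weight and `shard_colwise` on the bias are expected to produce.

    import torch

    out_features, in_features, world_size = 8, 4, 2

    # Row-wise sharding of the weight: each rank keeps out_features // world_size
    # rows, the same local shape the old out_features_per_partition code allocated.
    weight = torch.empty(out_features, in_features)
    local_weight = torch.chunk(weight, world_size, dim=0)[0]
    assert local_weight.shape == (out_features // world_size, in_features)

    # Sharding of the 1-D bias: each rank keeps out_features // world_size entries.
    bias = torch.empty(out_features)
    local_bias = torch.chunk(bias, world_size, dim=0)[0]
    assert local_bias.shape == (out_features // world_size,)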
@@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
                                  *args,
                                  **kwargs)

-        # TODO: copy the sharded weights
         with torch.no_grad():
             # the weigh to the linear layer is a transpose
             # thus shard on row is equal to shard on column
@@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
             if bias:
                 sharded_bias = shard_colwise(module.bias.data, process_group)
                 linear_1d.bias.copy_(sharded_bias)

         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
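A hedged sketch of the conversion step exercised by the two hunks above: the bias lines are visible in the diff, while the weight-copy line falls outside the shown context and is reconstructed from the surrounding comments (an `nn.Linear` weight is stored as `(out_features, in_features)`, the transpose of the matmul operand, so sharding its rows realizes column parallelism). The helper name is hypothetical, not the module's API.

    import torch
    import torch.nn as nn

    from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise

    def copy_native_linear_into_1d(module: nn.Linear, linear_1d, process_group=None):
        # Copy sharded slices of a plain nn.Linear into an already-constructed
        # column-parallel module, mirroring the pattern visible in the diff above.
        with torch.no_grad():
            sharded_weight = shard_rowwise(module.weight.data, process_group)
            linear_1d.weight.copy_(sharded_weight)
            if module.bias is not None:
                sharded_bias = shard_colwise(module.bias.data, process_group)
                linear_1d.bias.copy_(sharded_bias)
        return linear_1d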
@@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

         # Divide the weight matrix along the last dimension.
         self.input_size_per_partition = divide(in_features, self.num_partitions)

         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_colwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if self.stream_chunk_num > 1:
             # TODO() work for inference only
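To contrast the two layers: Linear1D_Col shards its weight row-wise (splitting out_features), while Linear1D_Row shards it column-wise (splitting in_features), so each rank consumes a slice of the input and the partial outputs are reduced afterwards. The sketch below uses `torch.chunk` as a stand-in for the d_tensor sharding, purely for illustration.

    import torch

    out_features, in_features, world_size = 6, 4, 2
    weight = torch.empty(out_features, in_features)

    # Linear1D_Col: shard_rowwise splits dim 0; each rank produces a slice of the
    # output features.
    col_local = torch.chunk(weight, world_size, dim=0)[0]
    assert col_local.shape == (out_features // world_size, in_features)

    # Linear1D_Row: shard_colwise splits dim 1; each rank consumes a slice of the
    # input features, and the partial products are summed across ranks.
    row_local = torch.chunk(weight, world_size, dim=1)[0]
    assert row_local.shape == (out_features, in_features // world_size)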
@@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
             return output
         else:
             return output, self.bias
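A minimal sketch of the skip_bias_add return convention seen above; the bias-addition branch is an assumption about code outside this hunk, and the standalone function is illustrative rather than the module's actual forward.

    import torch

    def finish_forward(output: torch.Tensor, bias, skip_bias_add: bool):
        # When the bias is skipped, it is handed back so the caller can fuse the
        # addition into a later kernel; otherwise a single tensor is returned.
        if not skip_bias_add:
            if bias is not None:
                output = output + bias
            return output
        else:
            return output, bias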