[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

@@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from ._operation import gather_forward_split_backward, reduce_input
 from .parallel_module import ParallelModule
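
The key new import is sharded_tensor_to_param, which wraps a d-tensor produced by shard_colwise / shard_rowwise into an nn.Parameter, so the module's weight keeps its distributed-tensor metadata and can later be saved and reloaded like an ordinary parameter. A minimal sketch of that flow, assuming an already-initialized process group (group setup and shapes are illustrative, not taken from this commit):

    import torch
    import torch.distributed as dist

    from colossalai.tensor.d_tensor.api import shard_colwise, sharded_tensor_to_param

    # assumes dist.init_process_group(...) has already been called elsewhere
    pg = dist.new_group()                        # tensor-parallel process group (illustrative)
    full_weight = torch.empty(1024, 64)          # full, unsharded embedding table
    sharded = shard_colwise(full_weight, pg)     # each rank keeps a slice of the embedding dim
    weight = sharded_tensor_to_param(sharded)    # nn.Parameter carrying d-tensor metadata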
@@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
-        if device is None:
-            device = get_current_device()
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+        # Parameters.
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_colwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
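
Because the weight is now a regular nn.Parameter backed by a d-tensor rather than a hand-partitioned local tensor, the sharded module can be saved and loaded through the usual state_dict machinery, which is what the commit title refers to. A rough usage sketch, assuming the constructor keywords shown above and a per-rank checkpoint file (import path, sizes, and file names are illustrative):

    import torch
    import torch.distributed as dist

    from colossalai.shardformer.layer import Embedding1D   # import path assumed

    pg = dist.new_group()   # tensor-parallel group (illustrative)
    embed = Embedding1D(num_embeddings=1024, embedding_dim=64, process_group=pg)

    # 'weight' in the state dict is this rank's shard, still tagged with d-tensor metadata
    state = embed.state_dict()
    torch.save(state, f"embedding_rank{dist.get_rank(pg)}.pt")

    # later: rebuild the module and load the shard back
    embed.load_state_dict(torch.load(f"embedding_rank{dist.get_rank(pg)}.pt"))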
@@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
-        self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+        # parameter
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
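
VocabParallelEmbedding1D takes the same route but shards along the vocabulary dimension via shard_rowwise, so each rank holds the rows [vocab_start_index, vocab_end_index) computed just above. A small worked sketch of that split, with made-up sizes:

    # Illustrative numbers only: 4-way tensor parallelism over a 1024-token vocabulary.
    tensor_parallel_size, tensor_parallel_rank = 4, 1
    num_embeddings, embedding_dim = 1024, 64

    num_embeddings_per_partition = num_embeddings // tensor_parallel_size    # 256
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition  # 256
    vocab_end_index = vocab_start_index + num_embeddings_per_partition       # 512
    # rank 1 owns vocabulary rows 256..511; shard_rowwise yields exactly this slice
    # of the (num_embeddings, embedding_dim) weight before sharded_tensor_to_param
    # wraps it as the module's parameter.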
@@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):
     def reset_parameters(self, weight_initializer) -> None:
         with self.randomizer.fork_rng(enable_cpu=True):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
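
With the attribute renamed, reset_parameters passes fan_in = num_embeddings and fan_out = embedding_dim of the full, unsharded table to the initializer. Any callable with the (tensor, fan_in=..., fan_out=...) signature fits this call site; a minimal hand-rolled example in the spirit of the colossalai.nn.init helpers (this exact function is not part of the library):

    import math
    import torch

    def xavier_normal_initializer(tensor: torch.Tensor, fan_in: int, fan_out: int) -> torch.Tensor:
        # Glorot/Xavier normal: std = sqrt(2 / (fan_in + fan_out))
        std = math.sqrt(2.0 / (fan_in + fan_out))
        with torch.no_grad():
            return tensor.normal_(mean=0.0, std=std)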