mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
@@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter
 
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
 
 from ._operation import gather_forward_split_backward, reduce_input
 from .parallel_module import ParallelModule
@@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
 
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
 
-        if device is None:
-            device = get_current_device()
-
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+        # Parameters.
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_colwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)
 
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
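For context, the new construction path in Embedding1D can be exercised directly with the d_tensor API imported above. The sketch below is illustrative and not part of the commit; it assumes a torch.distributed process group has already been set up (e.g. under torchrun) and that shard_colwise splits the trailing (embedding) dimension across the group, as the replaced embed_dim_per_partition logic did.

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import shard_colwise, sharded_tensor_to_param

# Illustrative setup: launch with torchrun so every rank joins the default group.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

process_group = None                        # None selects the default (world) group
num_embeddings, embedding_dim = 1000, 128   # hypothetical sizes
factory_kwargs = {'device': 'cuda', 'dtype': torch.float32}

# Same construction pattern as the new Embedding1D.__init__ above.
weight = torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
sharded_weight = shard_colwise(weight, process_group)    # split along the embedding (column) dimension
param = sharded_tensor_to_param(sharded_weight)          # this rank's shard, wrapped as an nn.Parameter

print(dist.get_rank(), type(param).__name__, tuple(param.shape))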
@@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
-        self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+        # parameter
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)
 
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
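VocabParallelEmbedding1D takes the complementary layout: the full (num_embeddings, embedding_dim) weight is split row-wise, so each rank owns the contiguous vocabulary slice between vocab_start_index and vocab_end_index. A minimal row-wise variant of the earlier sketch, under the same torchrun/default-group assumptions:

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_param

# Same illustrative setup as the previous sketch.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

num_embeddings, embedding_dim = 1000, 128   # hypothetical sizes
weight = torch.empty((num_embeddings, embedding_dim), device='cuda')
sharded_weight = shard_rowwise(weight, None)      # split along the vocabulary (row) dimension
param = sharded_tensor_to_param(sharded_weight)   # local block of rows, i.e. this rank's vocab slice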
@@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):
 
     def reset_parameters(self, weight_initializer) -> None:
         with self.randomizer.fork_rng(enable_cpu=True):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
 
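The commit title states the goal as module saving and loading: with each rank's weight now an ordinary nn.Parameter wrapping its local shard, a per-rank state_dict round trip looks roughly like the sketch below. The import path, constructor keywords, and file name are assumptions inferred from this diff, not taken from the commit, and the exact checkpoint semantics depend on the rest of the PR, which is not shown here.

import torch
import torch.distributed as dist

from colossalai.shardformer.layer import Embedding1D   # import path assumed, not shown in this diff

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Constructor keywords mirror the attributes visible above; treat them as assumptions.
emb = Embedding1D(num_embeddings=1000, embedding_dim=128, device='cuda')

# Each rank saves and reloads its own shard through the ordinary state_dict machinery.
path = f"embedding_rank{dist.get_rank()}.pt"   # hypothetical file name
torch.save(emb.state_dict(), path)
emb.load_state_dict(torch.load(path))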