From 507c0ad368dd8016f3faa4147ff5ce0b7e3ae0c6 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Fri, 16 Jun 2023 15:04:07 +0800
Subject: [PATCH] add vocabembedding layer

---
 colossalai/shardformer/layer/layers.py        | 65 ++++++++++++++++---
 .../test_vocab_parallel_embedding_1d.py       | 45 +++++++++++++
 2 files changed, 100 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py

diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py
index 586aec124..ad6e1896a 100644
--- a/colossalai/shardformer/layer/layers.py
+++ b/colossalai/shardformer/layer/layers.py
@@ -139,6 +139,7 @@ class Linear1D_Col(ParallelModule):
         with self.randomizer.fork_rng(enable_cpu=True):
             self.reset_parameters(weight_initializer, bias_initializer)
 
+    @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                            **kwargs) -> ParallelModule:
         r"""
@@ -587,6 +588,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -596,21 +599,63 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.process_group = process_group
 
-        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings_per_partition = num_embeddings
+        tensor_parallel_size = dist.get_world_size(group=process_group)
+        tensor_parallel_rank = dist.get_rank(group=process_group)
+
+        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings = self.num_embeddings_per_partition
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
 
-        self.reset_parameters(weight_initializer)
-        self._set_tensor_parallel_attributes()
-        set_parallel_input(False)
-        env.vocab_parallel = True
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+        # self.reset_parameters(weight_initializer)
+        # self._set_tensor_parallel_attributes()
+        # set_parallel_input(False)
+        # env.vocab_parallel = True
+
+    @staticmethod
+    def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+
+        # ensure only one process group is used
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        # create the parallel module
+        vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
+                                                      embedding_dim=embedding_dim,
+                                                      padding_idx=padding_idx,
+                                                      device=device,
+                                                      process_group=process_group,
+                                                      *args,
+                                                      **kwargs)
+        with torch.no_grad():
+            # shard and slice the weight along the vocabulary(num_embeddings) dimension
+            # the shape of the weight is (num_embeddings, embedding_dim)
+            shard_weight = shard_rowwise(module.weight.data, process_group)
+            vocab_embedding_1d.weight.data.copy_(shard_weight)
+
+        return vocab_embedding_1d
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
@@ -665,5 +710,5 @@ class VocabParallelEmbedding1D(ParallelLayer):
         # Mask the output embedding.
         output_parallel[input_mask, :] = 0.
         # Reduce across all the model parallel GPUs.
-        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        output = reduce_input(output_parallel, self.process_group)
         return output
diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
new file mode 100644
index 000000000..3df53e8a8
--- /dev/null
+++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
@@ -0,0 +1,45 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def check_vocab_embedding_1d():
+    embedding = nn.Embedding(128, 32).to('cuda')
+    dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
+
+    assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
+    assert dist_embedding_1d.num_embeddings == 64
+    assert dist_embedding_1d.embed_dim == 32
+
+    # check embedding correctness
+    x = torch.randint(0, 128, (4, 32)).to('cuda')
+    org_out = embedding(x)
+    dist_out = dist_embedding_1d(x)
+    assert_close(org_out, dist_out)
+
+    # check backward correctness
+    org_out.sum().backward()
+    dist_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank]
+    assert_close(target_grad, dist_embedding_1d.weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_vocab_embedding_1d()
+
+
+@rerun_if_address_is_in_use()
+def test_vocab_embedding():
+    spawn(run_dist, nprocs=2)
+
+
+if __name__ == '__main__':
+    test_vocab_embedding()
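
The sketch below is a minimal, single-process illustration of the scheme this patch implements: the embedding table is sharded row-wise over the vocabulary dimension, each shard looks up only the token indices that fall inside its local vocabulary range and zeroes the rest, and the partial results are summed (the in-process sum stands in for the dist.all_reduce that reduce_input performs across the process group). The function vocab_parallel_lookup and the loop over simulated ranks are illustrative assumptions, not part of the ColossalAI API.

# A minimal, single-process sketch (not the library implementation) of the
# vocab-parallel embedding scheme used by VocabParallelEmbedding1D.
import torch
import torch.nn as nn
import torch.nn.functional as F


def vocab_parallel_lookup(full_weight: torch.Tensor, indices: torch.Tensor,
                          world_size: int) -> torch.Tensor:
    """Emulate a vocab-parallel embedding lookup by looping over simulated ranks."""
    num_embeddings, embedding_dim = full_weight.shape
    assert num_embeddings % world_size == 0, 'num_embeddings must be divisible by world_size'
    per_partition = num_embeddings // world_size

    output = torch.zeros(*indices.shape, embedding_dim, dtype=full_weight.dtype)
    for rank in range(world_size):
        # each rank holds the rows [vocab_start, vocab_end) of the full table
        vocab_start = rank * per_partition
        vocab_end = vocab_start + per_partition
        shard = full_weight[vocab_start:vocab_end]

        # mask indices owned by other ranks and shift the rest into the local range
        input_mask = (indices < vocab_start) | (indices >= vocab_end)
        local_indices = (indices - vocab_start).clamp(0, per_partition - 1)
        partial = F.embedding(local_indices, shard)
        partial[input_mask, :] = 0.

        # in the real layer this sum is a dist.all_reduce over the process group
        output += partial
    return output


if __name__ == '__main__':
    torch.manual_seed(0)
    embedding = nn.Embedding(128, 32)
    x = torch.randint(0, 128, (4, 32))
    expected = embedding(x)
    actual = vocab_parallel_lookup(embedding.weight.detach(), x, world_size=2)
    torch.testing.assert_close(actual, expected)
    print('sharded lookup matches nn.Embedding')

Running the sketch on CPU checks that the sharded lookup reproduces a plain nn.Embedding forward, which is the same property the new unit test asserts with assert_close across two ranks.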