Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-21 02:40:48 +00:00)

Commit 507c0ad368 (parent 45d9384346): add vocabembedding layer
@@ -139,6 +139,7 @@ class Linear1D_Col(ParallelModule):
         with self.randomizer.fork_rng(enable_cpu=True):
             self.reset_parameters(weight_initializer, bias_initializer)

+    @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                            **kwargs) -> ParallelModule:
         r"""
@@ -587,6 +588,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
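Note: with device and process_group exposed on the constructor, the layer can be built directly on a chosen device and distributed group instead of relying on the global gpc context. A minimal usage sketch, assuming torch.distributed is already initialized (for example via colossalai.launch); the sizes are illustrative, not taken from the commit:

import torch
from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D

# hypothetical sizes: a 128-token vocabulary split across the default process group
embedding = VocabParallelEmbedding1D(num_embeddings=128,
                                     embedding_dim=32,
                                     device=torch.device('cuda'),
                                     process_group=None)    # None falls back to the default group
print(embedding.weight.shape)    # torch.Size([128 // world_size, 32]) on each rank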
@@ -596,21 +599,63 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.process_group = process_group

-        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings_per_partition = num_embeddings
+        tensor_parallel_size = dist.get_world_size(group=process_group)
+        tensor_parallel_rank = dist.get_rank(group=process_group)
+        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings = self.num_embeddings_per_partition
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))

-        self.reset_parameters(weight_initializer)
-        self._set_tensor_parallel_attributes()
-        set_parallel_input(False)
-        env.vocab_parallel = True
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+        # self.reset_parameters(weight_initializer)
+        # self._set_tensor_parallel_attributes()
+        # set_parallel_input(False)
+        # env.vocab_parallel = True
+
+    @staticmethod
+    def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+
+        # ensure only one process group is used
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        # create the parallel module
+        vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
+                                                      embedding_dim=embedding_dim,
+                                                      padding_idx=padding_idx,
+                                                      device=device,
+                                                      process_group=process_group,
+                                                      *args,
+                                                      **kwargs)
+        with torch.no_grad():
+            # shard and slice the weight along the vocabulary(num_embeddings) dimension
+            # the shape of the weight is (num_embeddings, embedding_dim)
+            shard_weight = shard_rowwise(module.weight.data, process_group)
+            vocab_embedding_1d.weight.data.copy_(shard_weight)
+
+        return vocab_embedding_1d

     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
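Note: the key change above is that the vocabulary is now genuinely partitioned. Each rank keeps only num_embeddings / world_size rows of the weight, owns the contiguous id range [vocab_start_index, vocab_end_index), and initializes its shard under a seed offset by the randomizer so the shards do not repeat the same random values. A small sketch of the index arithmetic with illustrative numbers (not library code):

# sketch of the partitioning arithmetic (illustrative numbers, not library code)
num_embeddings = 128          # full vocabulary size
tensor_parallel_size = 2      # dist.get_world_size(group=process_group)
tensor_parallel_rank = 1      # dist.get_rank(group=process_group)

num_embeddings_per_partition = num_embeddings // tensor_parallel_size    # 64
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition  # 64
vocab_end_index = vocab_start_index + num_embeddings_per_partition       # 128
# rank 1 owns weight rows for token ids 64..127; any other id is handled by a
# different rank and is masked out locally in the forward pass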
@@ -665,5 +710,5 @@ class VocabParallelEmbedding1D(ParallelLayer):
         # Mask the output embedding.
         output_parallel[input_mask, :] = 0.
         # Reduce across all the model parallel GPUs.
-        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        output = reduce_input(output_parallel, self.process_group)
         return output
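Note: reduce_input now runs over the layer's own process group instead of the global ParallelMode.PARALLEL_1D. Conceptually it is a sum all-reduce: every rank zeroes the output rows for token ids it does not own, so summing the partial lookups across ranks reconstructs the full embedding output. A rough stand-in using plain torch.distributed, ignoring the autograd handling the real helper provides (a sketch, not ColossalAI's implementation):

import torch
import torch.distributed as dist

def naive_reduce_input(output_parallel: torch.Tensor, process_group=None) -> torch.Tensor:
    # hypothetical helper: sum the masked partial embeddings across tensor-parallel ranks in place
    dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=process_group)
    return output_parallel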
@@ -0,0 +1,45 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def check_vocab_embedding_1d():
+    embedding = nn.Embedding(128, 32).to('cuda')
+    dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
+
+    assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
+    assert dist_embedding_1d.num_embeddings == 64
+    assert dist_embedding_1d.embed_dim == 32
+
+    # check embedding correctness
+    x = torch.randint(0, 128, (4, 32)).to('cuda')
+    org_out = embedding(x)
+    dist_out = dist_embedding_1d(x)
+    assert_close(org_out, dist_out)
+
+    # check backward correctness
+    org_out.sum().backward()
+    dist_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank]
+    assert_close(target_grad, dist_embedding_1d.weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_vocab_embedding_1d()
+
+
+@rerun_if_address_is_in_use()
+def test_vocab_embedding():
+    spawn(run_dist, nprocs=2)
+
+
+if __name__ == '__main__':
+    test_vocab_embedding()
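Note: the new test spawns two NCCL workers and checks that the sharded layer reproduces nn.Embedding in both the forward output and the per-rank slice of the weight gradient. The same forward equivalence can be sanity-checked on a single process; the sketch below simulates the shard-mask-sum pipeline without torch.distributed (the loop over ranks stands in for the all-reduce, and all names here are illustrative, not part of the repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
full = nn.Embedding(128, 32)
x = torch.randint(0, 128, (4, 32))

world_size = 2
shard_rows = 128 // world_size    # rows of the table held by each simulated rank

with torch.no_grad():
    full_out = full(x)
    partial_sum = torch.zeros_like(full_out)
    for rank in range(world_size):
        start, end = rank * shard_rows, (rank + 1) * shard_rows
        shard = full.weight[start:end]                     # this rank's slice of the table
        mask = (x < start) | (x >= end)                    # ids owned by other ranks
        local_ids = (x - start).clamp(0, shard_rows - 1)   # shift ids into the local range
        out = nn.functional.embedding(local_ids, shard)
        out[mask, :] = 0.                                  # zero rows this rank does not own
        partial_sum += out                                 # stands in for the all-reduce

assert torch.allclose(partial_sum, full_out)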