mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
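As a reading aid (not part of the commit), a minimal sketch of the inplace-sharding pattern the reworked constructors below follow: instead of allocating a fresh tensor, sharding it, and wrapping it in a brand-new parameter, the module now keeps the caller's nn.Parameter and only swaps its .data for the local shard via sharded_tensor_to_existing_param. The helper name shard_embedding_weight_inplace is made up for illustration; the d_tensor API calls are the ones imported in the diff.

import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.tensor.d_tensor.api import (
    is_distributed_tensor,
    shard_colwise,
    sharded_tensor_to_existing_param,
)


def shard_embedding_weight_inplace(weight: nn.Parameter, process_group: ProcessGroup) -> nn.Parameter:
    # Weights that are already distributed tensors (e.g. lazily initialized or restored) are left alone.
    if not is_distributed_tensor(weight):
        sharded = shard_colwise(weight.data, process_group)    # split along the embedding_dim axis
        sharded_tensor_to_existing_param(sharded, weight)      # write the shard back into the same Parameter
    return weight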
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,12 @@ from torch.distributed import ProcessGroup
 from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
+from colossalai.tensor.d_tensor.api import (
+    is_distributed_tensor,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_existing_param,
+)
 
 from ._operation import gather_forward_split_backward, reduce_forward
 from .parallel_module import ParallelModule
@@ -60,6 +65,7 @@ class Embedding1D(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  gather_output: bool = True,
+                 weight: Optional[nn.Parameter] = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -74,18 +80,24 @@ class Embedding1D(ParallelModule):
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
 
-        # Parameters.
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
-        sharded_weight = shard_colwise(weight, process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        with self.randomizer.fork_rng(enable_cpu=True):
-            self.reset_parameters(weight_initializer)
+        # Parameters.
+        if weight is None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_colwise(self.weight.data, process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if weight is None:
+            with self.randomizer.fork_rng(enable_cpu=True):
+                self.reset_parameters(weight_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Embedding,
@@ -121,14 +133,10 @@ class Embedding1D(ParallelModule):
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
                                 sparse=sparse,
+                                weight=module.weight,
                                 *args,
                                 **kwargs)
 
-        # copy the weight
-        with torch.no_grad():
-            sharded_weight = shard_colwise(module.weight.data, process_group)
-            embedding.weight.copy_(sharded_weight)
-
         return embedding
 
     def reset_parameters(self, weight_initializer) -> None:
@@ -143,7 +151,6 @@ class Embedding1D(ParallelModule):
 
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-
         if self.gather_output:
             output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
             return output
@@ -188,6 +195,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 weight: Optional[nn.Parameter] = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -207,16 +215,23 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
-        # parameter
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
-        sharded_weight = shard_rowwise(weight, process_group)
-        self.weight = sharded_tensor_to_param(sharded_weight)
-
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-        self.reset_parameters(weight_initializer)
+
+        # parameter
+        if weight is None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+            self.weight = weight
+        if not is_distributed_tensor(self.weight):
+            sharded_weight = shard_rowwise(self.weight.data, process_group)
+            sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+        if weight is None:
+            self.reset_parameters(weight_initializer)
 
     @staticmethod
     def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -243,15 +258,10 @@ class VocabParallelEmbedding1D(ParallelModule):
                                                       padding_idx=padding_idx,
                                                       device=device,
                                                       process_group=process_group,
+                                                      weight=module.weight,
                                                       *args,
                                                       **kwargs)
 
-        with torch.no_grad():
-            # shard and slice the weight along the vocabulary(num_embeddings) dimension
-            # the shape of the weight is (num_embeddings, embedding_dim)
-            shard_weight = shard_rowwise(module.weight.data, process_group)
-            vocab_embedding_1d.weight.data.copy_(shard_weight)
-
         return vocab_embedding_1d
 
     def reset_parameters(self, weight_initializer) -> None:
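A hedged usage sketch (not from the repository) of what the weight=module.weight forwarding above enables, assuming Embedding1D is exported from colossalai.shardformer.layer: from_native_module now hands the source module's own Parameter to the wrapper, which shards it in place, so the source nn.Embedding and the parallel module keep pointing at the same, now sharded, parameter and tied/shared weights stay tied. The single-process group is only there to make the sketch self-contained.

import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer import Embedding1D

# Single-process gloo group purely for illustration; real tensor parallelism
# launches one process per rank and builds the group accordingly.
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)
tp_group = dist.new_group(ranks=[0])

module = nn.Embedding(num_embeddings=1000, embedding_dim=128)
parallel_module = Embedding1D.from_native_module(module, process_group=tp_group)

# Inplace sharding: the wrapper reuses the original Parameter object and only its
# .data has been replaced by the column-wise shard, so no duplicate full-size
# copy of the embedding table survives the conversion.
assert parallel_module.weight is module.weight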