[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
This commit is contained in:
Hongxin Liu
2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union
import torch
import torch.distributed as dist
@@ -13,7 +13,12 @@ from torch.distributed import ProcessGroup
from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.tensor.d_tensor.api import (
is_distributed_tensor,
shard_colwise,
shard_rowwise,
sharded_tensor_to_existing_param,
)
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
@@ -60,6 +65,7 @@ class Embedding1D(ParallelModule):
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@@ -74,18 +80,24 @@ class Embedding1D(ParallelModule):
self.embed_kwargs = kwargs
self.gather_output = gather_output
# Parameters.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_colwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
# Parameters.
if weight is None:
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_colwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding,
@@ -121,14 +133,10 @@ class Embedding1D(ParallelModule):
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
weight=module.weight,
*args,
**kwargs)
# copy the weight
with torch.no_grad():
sharded_weight = shard_colwise(module.weight.data, process_group)
embedding.weight.copy_(sharded_weight)
return embedding
def reset_parameters(self, weight_initializer) -> None:
@@ -143,7 +151,6 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
@@ -188,6 +195,7 @@ class VocabParallelEmbedding1D(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@@ -207,16 +215,23 @@ class VocabParallelEmbedding1D(ParallelModule):
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
# parameter
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_rowwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
self.reset_parameters(weight_initializer)
# parameter
if weight is None:
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if weight is None:
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -243,15 +258,10 @@ class VocabParallelEmbedding1D(ParallelModule):
padding_idx=padding_idx,
device=device,
process_group=process_group,
weight=module.weight,
*args,
**kwargs)
with torch.no_grad():
# shard and slice the weight along the vocabulary(num_embeddings) dimension
# the shape of the weight is (num_embeddings, embedding_dim)
shard_weight = shard_rowwise(module.weight.data, process_group)
vocab_embedding_1d.weight.data.copy_(shard_weight)
return vocab_embedding_1d
def reset_parameters(self, weight_initializer) -> None: