Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-14 14:13:22 +00:00)
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
Parent: c6f9c2c033
Commit: 0192011688
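In short, this commit lets every shardformer layer and policy accept a model built under lazy initialization: parameters are recorded as LazyTensors instead of being allocated, and are materialized only once sharding decisions are made. A minimal sketch of the intended flow, using only APIs that appear in the diff below:

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    # Build the model under the context: no real weight memory is allocated yet.
    with LazyInitContext():
        model = nn.Linear(1024, 1024)

    # Later, instantiate all parameters in place (shardformer now does this
    # automatically inside ModelSharder and each from_native_module()).
    LazyInitContext.materialize(model)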
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -5,6 +5,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Parameter
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
@@ -95,8 +96,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
-    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
+    if cls_to_become is Parameter:
+        # to fit UninitializedParameter
+        delattr(tensor, '_is_param')
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
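The `_is_param` deletion pairs with the deepcopy hack later in this file. On PyTorch versions where nn.Parameter's metaclass implements an __instancecheck__ hook (added to support UninitializedParameter), any Tensor carrying a truthy `_is_param` attribute passes isinstance checks against Parameter, so the flag must be removed before the class swap yields a genuine Parameter. A toy illustration, assuming such a PyTorch version:

    import torch
    import torch.nn as nn

    t = torch.empty(2)
    t._is_param = True                    # the flag LazyTensor.__deepcopy__ sets
    print(isinstance(t, nn.Parameter))    # True: the metaclass honors the flag
    del t._is_param
    print(isinstance(t, nn.Parameter))    # False again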
@@ -190,10 +194,10 @@ class LazyTensor(torch.Tensor):
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        self._factory_method = None
-        self._op_buffer = None
-        self._materialized_data = None
-        self._meta_data = None
+        delattr(self, '_factory_method')
+        delattr(self, '_op_buffer')
+        delattr(self, '_materialized_data')
+        delattr(self, '_meta_data')

     @staticmethod
     def _replace_with_materialized(x):
@@ -346,20 +350,19 @@ class LazyTensor(torch.Tensor):
         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
-            copied = new_tensor.detach().clone()
-            if new_tensor.requires_grad:
-                copied.requires_grad_()
-            return copied
+            return _copy_tensor(new_tensor, new_tensor.requires_grad)

         if self._materialized_data is not None:
             # self is early materialized
-            copied = self._materialized_data.detach().clone()
-            if self.requires_grad:
-                copied.requires_grad_()
+            copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
             target = LazyTensor(factory_fn, meta_data=self._meta_data)

+        if isinstance(self, Parameter):
+            # hack isinstance check of parameter
+            target._is_param = True
+
         memo[id(self)] = target
         return target

@@ -404,6 +407,10 @@ class LazyTensor(torch.Tensor):
     def __hash__(self):
         return id(self)

+    def __rpow__(self, other):
+        dtype = torch.result_type(self, other)
+        return torch.tensor(other, dtype=dtype, device=self.device)**self
+

 class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
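The new `__rpow__` covers expressions like `2 ** lazy_tensor`, where Python dispatches to the right-hand operand. Promoting the scalar with `torch.result_type` keeps the dtype consistent with ordinary tensor promotion; a quick check of the same recipe on a plain tensor:

    import torch

    t = torch.full((2,), 3.0)
    dtype = torch.result_type(t, 2)        # float32
    base = torch.tensor(2, dtype=dtype)    # same construction as the added __rpow__
    print(base ** t)                       # tensor([8., 8.])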
@@ -524,7 +531,7 @@ class LazyInitContext:

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -541,7 +548,7 @@ class LazyInitContext:
                    device_mesh: DeviceMesh,
                    sharding_spec_dict: Dict[str, ShardingSpec],
                    verbose: bool = False) -> nn.Module:
-        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -613,3 +620,9 @@ def _is_int_tuple(args) -> bool:
         if not isinstance(x, int):
             return False
     return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
+    copied = tensor.data.clone()
+    copied.requires_grad = requires_grad
+    return copied
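`_copy_tensor` replaces the earlier detach().clone() pattern: cloning `tensor.data` copies the storage without recording autograd history, so the result is always a leaf, and requires_grad is then set explicitly. A small demonstration of why this matters for non-leaf sources:

    import torch

    def _copy_tensor(tensor: torch.Tensor, requires_grad: bool) -> torch.Tensor:
        copied = tensor.data.clone()          # no autograd history on the copy
        copied.requires_grad = requires_grad  # setting this on a non-leaf clone() would raise
        return copied

    src = torch.randn(2, requires_grad=True) * 2.0    # non-leaf tensor
    dst = _copy_tensor(src, True)
    assert dst.is_leaf and dst.requires_grad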
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -9,8 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -95,6 +95,7 @@ class Embedding1D(ParallelModule):
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         num_embedding = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -223,6 +224,7 @@ class VocabParallelEmbedding1D(ParallelModule):
         r"""
         Convert a native pytorch embedding module to a parallel module.
         """
+        LazyInitContext.materialize(module)
         # get the origin attributes
         num_embeddings = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -243,6 +245,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                                             process_group=process_group,
                                             *args,
                                             **kwargs)
+
         with torch.no_grad():
             # shard and slice the weight along the vocabulary(num_embeddings) dimension
             # the shape of the weight is (num_embeddings, embedding_dim)
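Because `from_native_module` now materializes the incoming module first, callers may hand over a module built under LazyInitContext directly, as the updated layer tests further down do. A usage sketch:

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext
    from colossalai.shardformer.layer import Embedding1D

    with LazyInitContext():
        embedding = nn.Embedding(128, 32)
    # Materialized internally, then sharded along the embedding dimension.
    embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)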
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -106,6 +107,7 @@ class Linear1D_Col(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
@@ -242,6 +244,7 @@ class Linear1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn

+from colossalai.lazy import LazyInitContext
+
 __all__ = ['FusedLayerNorm', 'FusedRMSNorm']

 FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ class FusedLayerNorm():
             raise ImportError(
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')

+        LazyInitContext.materialize(module)
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
@@ -84,6 +87,7 @@ class FusedRMSNorm():
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
             )

+        LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import (
@@ -231,6 +232,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -380,6 +382,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -428,9 +431,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
             src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

             origin_device = self.bias.device
-            self.bias = self.bias.cuda()
+            self.bias.data = self.bias.cuda()
             dist.broadcast(self.bias, src=src_rank, group=self.process_group)
-            self.bias = self.bias.to(origin_device)
+            self.bias.data = self.bias.to(origin_device)

     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
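The bias fix above matters because `self.bias` is a registered nn.Parameter: assigning the plain tensor that `.cuda()` returns to a parameter attribute raises a TypeError in nn.Module.__setattr__, and replacing the Parameter object would also break any references held elsewhere. Mutating `.data` moves the storage while keeping the same Parameter. A standalone illustration:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # m.bias = m.bias.cuda()   # TypeError: cannot assign a Tensor as parameter 'bias'
    m.bias.data = m.bias.data.to(device)   # same Parameter object, moved storage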
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -46,11 +46,12 @@ class BertPolicy(Policy):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -229,10 +230,11 @@ class BertForPreTrainingPolicy(BertPolicy):
         return []

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model


@@ -269,10 +271,11 @@ class BertLMHeadModelPolicy(BertPolicy):
         return []

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model


@@ -288,10 +291,11 @@ class BertForMaskedLMPolicy(BertPolicy):
         return module_policy

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
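The preprocess hunk now pads the vocabulary only when tensor parallelism is actually enabled, so single-device (and lazy-init, non-TP) runs leave the embedding untouched. The padding arithmetic rounds the vocabulary up to the next multiple of the tensor-parallel world size:

    # e.g. BERT's 30522-token vocabulary on 4 tensor-parallel ranks
    vocab_size, world_size = 30522, 4
    if vocab_size % world_size != 0:
        vocab_size = vocab_size + world_size - vocab_size % world_size
    print(vocab_size)  # 30524, divisible by 4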
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -17,11 +17,12 @@ class BloomPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -128,16 +129,13 @@ class BloomForCausalLMPolicy(BloomPolicy):
         return policy

     def postprocess(self):
-        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
-
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-
-            if not isinstance(param, nn.Parameter):
-                param = nn.Parameter(param)
-
-            # tie weights
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
+
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                # tie weights
+                setattr_(self.model, v, param)
         return self.model
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -21,11 +21,12 @@ class GPT2Policy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -142,10 +143,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         return module_policy

     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model


@@ -172,10 +174,11 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         return module_policy

     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
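These postprocess hunks re-tie the output head to the (possibly resized or sharded) input embedding, and now do so only when tensor parallelism is on. In plain PyTorch terms, the binding assigns the very same Parameter object to both modules:

    import torch.nn as nn

    embed = nn.Embedding(100, 16)
    lm_head = nn.Linear(16, 100, bias=False)
    lm_head.weight = embed.weight       # tied: one Parameter, two modules
    assert lm_head.weight is embed.weight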
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -15,13 +15,14 @@ class LlamaPolicy(Policy):
         pass

     def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
+        if self.shard_config.enable_tensor_parallelism:
+            # Resize embedding
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size

-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)

         return self.model
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -19,11 +19,12 @@ class OPTPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -116,14 +117,15 @@ class OPTForCausalLMPolicy(OPTPolicy):
         return policy

     def postprocess(self):
-        binding_map = {
-            'model.decoder.embed_tokens': 'lm_head',
-        }
-
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {
+                'model.decoder.embed_tokens': 'lm_head',
+            }
+
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight

         return self.model
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -24,11 +24,12 @@ class T5BasePolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -164,11 +165,12 @@ class T5BasePolicy(Policy):
         return policy

     def postprocess(self):
-        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]

-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model


@@ -211,13 +213,13 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):

     def postprocess(self):
         super().postprocess()
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"shared": "lm_head"}

-        binding_map = {"shared": "lm_head"}
-
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight

         return self.model

@@ -239,11 +241,12 @@ class T5EncoderPolicy(T5BasePolicy):
         return base_policy

     def postprocess(self):
-        binding_map = [
-            ["shared", "encoder.embed_tokens"],
-        ]
-
-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [
+                ["shared", "encoder.embed_tokens"],
+            ]
+
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor

-from colossalai.lazy import LazyTensor
+from colossalai.lazy import LazyInitContext

 from .._utils import getattr_, setattr_
 from ..policies.auto_policy import get_autopolicy
@@ -192,10 +192,4 @@ class ModelSharder(object):
         r"""
         Materialize the model if lazy initialization is used
         """
-        for p in self.model.parameters():
-            if isinstance(p, LazyTensor):
-                p.materialize()
-
-        for b in self.model.buffers():
-            if isinstance(b, LazyTensor):
-                b.materialize()
+        LazyInitContext.materialize(self.model)
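`LazyInitContext.materialize(self.model)` subsumes the hand-rolled traversal it replaces; for reference, the old behavior amounted to the sketch below (LazyTensor was the previous import in this file). The static helper walks the module tree once and, per the lazy_init.py changes above, also restores the proper Parameter class on each materialized tensor.

    from colossalai.lazy import LazyTensor

    def materialize_manually(model):
        # What the one-line call replaces: materialize every lazy
        # parameter and buffer in the module tree.
        for p in model.parameters():
            if isinstance(p, LazyTensor):
                p.materialize()
        for b in model.buffers():
            if isinstance(b, LazyTensor):
                b.materialize()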
--- a/tests/test_shardformer/test_layer/test_embedding.py
+++ b/tests/test_shardformer/test_layer/test_embedding.py
@@ -1,15 +1,22 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Embedding1D
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_embedding_1d():
-    embedding = nn.Embedding(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        embedding = nn.Embedding(32, 128).cuda()
     embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)

     assert embedding_1d.weight.shape == torch.Size([32, 64])
--- a/tests/test_shardformer/test_layer/test_layernorm.py
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -1,14 +1,21 @@
+from contextlib import nullcontext
+
 import torch
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import FusedLayerNorm
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_layernorm():
-    norm = nn.LayerNorm(128, 0.00001).cuda()
+@parameterize('lazy_init', [False, True])
+def check_layernorm(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        norm = nn.LayerNorm(128, 0.00001).cuda()
     norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)

     assert norm1d.weight.shape == torch.Size([128])
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -1,16 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.tensor.d_tensor import is_distributed_tensor
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_linear_1d_col():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

     # ensure that the parameters are distributed
@@ -50,8 +57,12 @@ def check_linear_1d_col():
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_1d_row():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()
     linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -83,9 +94,13 @@ def check_linear_1d_row():
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row():
-    linear_1 = nn.Linear(32, 128).cuda()
-    linear_2 = nn.Linear(128, 32).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_col_plus_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear_1 = nn.Linear(32, 128).cuda()
+        linear_2 = nn.Linear(128, 32).cuda()
     linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
     linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -1,12 +1,15 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


 # This code is copied from https://github.com/huggingface/transformers
@@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-def check_linear_conv_1d_col():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = Conv1D(192, 48).cuda()
     linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
                                                                    process_group=None,
                                                                    gather_output=True,
@@ -80,8 +87,12 @@ def check_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)


-def check_linear_conv_1d_row():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = Conv1D(192, 48).cuda()
     linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

     assert linear.weight.shape == torch.Size([48, 192])
--- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
+++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
@@ -1,15 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
-from colossalai.shardformer.layer import VocabParallelEmbedding1D
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_vocab_embedding_1d():
-    embedding = nn.Embedding(128, 32).to('cuda')
+@parameterize('lazy_init', [False, True])
+def check_vocab_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        embedding = nn.Embedding(128, 32).to('cuda')
     dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)

     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,19 +1,24 @@
 import copy
+from contextlib import nullcontext

+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer import ShardConfig, ShardFormer


-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
-    # create new model
-    org_model = model_fn().cuda()
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+    ctx = LazyInitContext() if use_lazy_init else nullcontext()
+    with ctx:
+        # create new model
+        org_model = model_fn()
+        model_copy = copy.deepcopy(org_model)
+    if use_lazy_init:
+        ctx.materialize(org_model)
     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism)
-    model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
-    return org_model, sharded_model.cuda()
+    return org_model.cuda(), sharded_model.cuda()


 def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
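Note the ordering in build_model: the deepcopy happens inside the context, while both models are still cheap LazyTensors and therefore share the same initialization recipe; only the reference model is materialized eagerly, and the sharded copy is materialized later inside shard_former.optimize() by the ModelSharder change above. A usage sketch, as it would appear inside one of the test_shard_*.py files below (which already import build_model and model_zoo):

    # pick any registered model factory from the zoo
    model_fn = next(iter(model_zoo.get_sub_registry('transformers_gpt').values()))[0]
    org_model, sharded_model = build_model(model_fn, use_lazy_init=True)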
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
                        atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_fused_normalization', [False, True])
+@parameterize('enable_tensor_parallelism', [False, True])
+@parameterize('use_lazy_init', [False, True])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
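Stacking a third @parameterize multiplies the configuration grid (now 2x2x2 per model). An illustrative stand-in for what colossalai.testing.parameterize does, assuming its documented run-once-per-value behavior:

    from functools import wraps

    def parameterize(name, values):
        # Illustrative only: call the wrapped function once per value,
        # injected as a keyword argument; decorators stack multiplicatively.
        def decorator(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                for v in values:
                    fn(*args, **{**kwargs, name: v})
            return wrapper
        return decorator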
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
--- a/tests/test_shardformer/test_with_torch_ddp.py
+++ b/tests/test_shardformer/test_with_torch_ddp.py
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import pytest
 import torch
 import torch.distributed as dist
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-def check_shardformer_with_ddp(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize('lazy_init', [True, False])
+def check_shardformer_with_ddp(lazy_init: bool):

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
     shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
     shardformer = ShardFormer(shard_config=shard_config)

+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
-        model = model_fn().cuda()
+        with ctx:
+            model = model_fn().cuda()
         sharded_model, _ = shardformer.optimize(model)

         # add ddp
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
         torch.cuda.empty_cache()


+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_shardformer_with_ddp()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_gpt2():
-    spawn(check_shardformer_with_ddp, 4)
+    spawn(run_dist, 4)


 if __name__ == "__main__":
     test_gpt2()
-    test_gpt2()