[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
Hongxin Liu 2023-07-10 10:48:53 +08:00 committed by GitHub
parent c6f9c2c033
commit 0192011688
25 changed files with 263 additions and 157 deletions

View File

@@ -5,6 +5,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Parameter
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
@@ -95,8 +96,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
-    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
+    if cls_to_become is Parameter:
+        # to fit UninitializedParameter
+        delattr(tensor, '_is_param')
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
@@ -190,10 +194,10 @@ class LazyTensor(torch.Tensor):
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        self._factory_method = None
-        self._op_buffer = None
-        self._materialized_data = None
-        self._meta_data = None
+        delattr(self, '_factory_method')
+        delattr(self, '_op_buffer')
+        delattr(self, '_materialized_data')
+        delattr(self, '_meta_data')

     @staticmethod
     def _replace_with_materialized(x):
@@ -346,20 +350,19 @@ class LazyTensor(torch.Tensor):

         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
-            copied = new_tensor.detach().clone()
-            if new_tensor.requires_grad:
-                copied.requires_grad_()
-            return copied
+            return _copy_tensor(new_tensor, new_tensor.requires_grad)

         if self._materialized_data is not None:
             # self is early materialized
-            copied = self._materialized_data.detach().clone()
-            if self.requires_grad:
-                copied.requires_grad_()
+            copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
             target = LazyTensor(factory_fn, meta_data=self._meta_data)

+        if isinstance(self, Parameter):
+            # hack isinstance check of parameter
+            target._is_param = True
+
         memo[id(self)] = target
         return target
@@ -404,6 +407,10 @@ class LazyTensor(torch.Tensor):

     def __hash__(self):
         return id(self)

+    def __rpow__(self, other):
+        dtype = torch.result_type(self, other)
+        return torch.tensor(other, dtype=dtype, device=self.device)**self
+

 class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
@@ -524,7 +531,7 @@ class LazyInitContext:

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -541,7 +548,7 @@ class LazyInitContext:
                        device_mesh: DeviceMesh,
                        sharding_spec_dict: Dict[str, ShardingSpec],
                        verbose: bool = False) -> nn.Module:
-        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -613,3 +620,9 @@ def _is_int_tuple(args) -> bool:
         if not isinstance(x, int):
             return False
     return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
+    copied = tensor.data.clone()
+    copied.requires_grad = requires_grad
+    return copied

View File

@@ -9,8 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -95,6 +95,7 @@ class Embedding1D(ParallelModule):
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         num_embedding = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -223,6 +224,7 @@ class VocabParallelEmbedding1D(ParallelModule):
         r"""
         Convert a native pytorch embedding module to a parallel module.
         """
+        LazyInitContext.materialize(module)
         # get the origin attributes
         num_embeddings = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -243,6 +245,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                               process_group=process_group,
                               *args,
                               **kwargs)
+
         with torch.no_grad():
             # shard and slice the weight along the vocabulary(num_embeddings) dimension
             # the shape of the weight is (num_embeddings, embedding_dim)

View File

@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -106,6 +107,7 @@ class Linear1D_Col(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
@@ -242,6 +244,7 @@ class Linear1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
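Note (not part of the diff): the `LazyInitContext.materialize(module)` calls added above let a layer built under `LazyInitContext` be passed straight to `from_native_module()`. A sketch mirroring the updated layer tests later in this commit (shapes follow their 2-GPU setup):

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext
    from colossalai.shardformer.layer import Linear1D_Col

    with LazyInitContext():
        linear = nn.Linear(32, 128)    # weight and bias are still LazyTensors here

    # from_native_module() now materializes the module first, then shards it column-wise
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)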

View File

@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn

+from colossalai.lazy import LazyInitContext
+
 __all__ = ['FusedLayerNorm', 'FusedRMSNorm']

 FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ class FusedLayerNorm():
             raise ImportError(
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')

+        LazyInitContext.materialize(module)
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
@@ -84,6 +87,7 @@ class FusedRMSNorm():
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
             )

+        LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]

View File

@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import (
@@ -231,6 +232,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -380,6 +382,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -428,9 +431,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
             src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

             origin_device = self.bias.device
-            self.bias = self.bias.cuda()
+            self.bias.data = self.bias.cuda()
             dist.broadcast(self.bias, src=src_rank, group=self.process_group)
-            self.bias = self.bias.to(origin_device)
+            self.bias.data = self.bias.to(origin_device)

     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
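Note (not part of the diff): the switch to `self.bias.data = ...` above moves the bias storage without rebinding the attribute, so the registered parameter object survives the round trip to CUDA and back. An illustrative CPU-only sketch of the difference:

    import torch
    import torch.nn as nn

    p = nn.Parameter(torch.zeros(4))
    p.data = torch.ones(4)            # storage swapped in place; p is still the same Parameter
    assert isinstance(p, nn.Parameter)

    q = nn.Parameter(torch.zeros(4))
    q = q.detach().clone()            # rebinding: q is now a plain Tensor, not a Parameter
    assert not isinstance(q, nn.Parameter)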

View File

@@ -46,11 +46,12 @@ class BertPolicy(Policy):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -229,10 +230,11 @@ class BertForPreTrainingPolicy(BertPolicy):
         return []

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
@@ -269,10 +271,11 @@ class BertLMHeadModelPolicy(BertPolicy):
         return []

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
@@ -288,10 +291,11 @@ class BertForMaskedLMPolicy(BertPolicy):
         return module_policy

     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model

View File

@@ -17,11 +17,12 @@ class BloomPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -128,16 +129,13 @@ class BloomForCausalLMPolicy(BloomPolicy):
         return policy

     def postprocess(self):
-        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            # tie weights
-            if not isinstance(param, nn.Parameter):
-                param = nn.Parameter(param)
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)

-            # tie weights
-            setattr_(self.model, v, param)
+                # tie weights
+                setattr_(self.model, v, param)
         return self.model

View File

@@ -21,11 +21,12 @@ class GPT2Policy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -142,10 +143,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         return module_policy

     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
@@ -172,10 +174,11 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         return module_policy

     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model

View File

@@ -15,13 +15,14 @@ class LlamaPolicy(Policy):
         pass

     def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            # Resize embedding
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)

         return self.model

View File

@@ -19,11 +19,12 @@ class OPTPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -116,14 +117,15 @@ class OPTForCausalLMPolicy(OPTPolicy):
         return policy

     def postprocess(self):
-        binding_map = {
-            'model.decoder.embed_tokens': 'lm_head',
-        }
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {
+                'model.decoder.embed_tokens': 'lm_head',
+            }

-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight

         return self.model

View File

@@ -24,11 +24,12 @@ class T5BasePolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model

     def module_policy(self):
@@ -164,11 +165,12 @@ class T5BasePolicy(Policy):
         return policy

     def postprocess(self):
-        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]

-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)

         return self.model
@@ -211,13 +213,13 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):

     def postprocess(self):
         super().postprocess()
-
-        binding_map = {"shared": "lm_head"}
-
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"shared": "lm_head"}
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight

         return self.model
@@ -239,11 +241,12 @@ class T5EncoderPolicy(T5BasePolicy):
         return base_policy

     def postprocess(self):
-        binding_map = [
-            ["shared", "encoder.embed_tokens"],
-        ]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [
+                ["shared", "encoder.embed_tokens"],
+            ]

-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)

         return self.model

View File

@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor

-from colossalai.lazy import LazyTensor
+from colossalai.lazy import LazyInitContext

 from .._utils import getattr_, setattr_
 from ..policies.auto_policy import get_autopolicy
@@ -192,10 +192,4 @@ class ModelSharder(object):
         r"""
         Materialize the model if lazy initialization is used
         """
-        for p in self.model.parameters():
-            if isinstance(p, LazyTensor):
-                p.materialize()
-        for b in self.model.buffers():
-            if isinstance(b, LazyTensor):
-                b.materialize()
+        LazyInitContext.materialize(self.model)

View File

@@ -1,15 +1,22 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Embedding1D
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_embedding_1d():
-    embedding = nn.Embedding(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        embedding = nn.Embedding(32, 128).cuda()

     embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)

     assert embedding_1d.weight.shape == torch.Size([32, 64])

View File

@@ -1,14 +1,21 @@
+from contextlib import nullcontext
+
 import torch
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import FusedLayerNorm
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_layernorm():
-    norm = nn.LayerNorm(128, 0.00001).cuda()
+@parameterize('lazy_init', [False, True])
+def check_layernorm(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        norm = nn.LayerNorm(128, 0.00001).cuda()

     norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)

     assert norm1d.weight.shape == torch.Size([128])

View File

@@ -1,16 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.tensor.d_tensor import is_distributed_tensor
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_linear_1d_col():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()

     linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

     # ensure that the parameters are distributed
@@ -50,8 +57,12 @@ def check_linear_1d_col():
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_1d_row():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()

     linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -83,9 +94,13 @@ def check_linear_1d_row():
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row():
-    linear_1 = nn.Linear(32, 128).cuda()
-    linear_2 = nn.Linear(128, 32).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_col_plus_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        linear_1 = nn.Linear(32, 128).cuda()
+        linear_2 = nn.Linear(128, 32).cuda()

     linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
     linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)

View File

@@ -1,12 +1,15 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 # This code is copied from https://github.com/huggingface/transformers
@@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-def check_linear_conv_1d_col():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        linear = Conv1D(192, 48).cuda()

     linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
                                                                    process_group=None,
                                                                    gather_output=True,
@@ -80,8 +87,12 @@ def check_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)


-def check_linear_conv_1d_row():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        linear = Conv1D(192, 48).cuda()

     linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

     assert linear.weight.shape == torch.Size([48, 192])

View File

@@ -1,15 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close

 import colossalai
-from colossalai.shardformer.layer import VocabParallelEmbedding1D
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_vocab_embedding_1d():
-    embedding = nn.Embedding(128, 32).to('cuda')
+@parameterize('lazy_init', [False, True])
+def check_vocab_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+    with ctx:
+        embedding = nn.Embedding(128, 32).to('cuda')

     dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)

     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])

View File

@@ -1,19 +1,24 @@
 import copy
+from contextlib import nullcontext

+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer import ShardConfig, ShardFormer


-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
-    # create new model
-    org_model = model_fn().cuda()
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+    ctx = LazyInitContext() if use_lazy_init else nullcontext()
+    with ctx:
+        # create new model
+        org_model = model_fn()
+        model_copy = copy.deepcopy(org_model)
+    if use_lazy_init:
+        ctx.materialize(org_model)

     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism)
-    model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
-    return org_model, sharded_model.cuda()
+    return org_model.cuda(), sharded_model.cuda()


 def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
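Note (not part of the diff): a sketch of the deepcopy-then-materialize flow the helper above relies on; the toy model is an assumption, while `LazyInitContext`, its `materialize()` staticmethod and the lazy-aware `deepcopy` come from this commit.

    import copy
    import torch.nn as nn
    from colossalai.lazy import LazyInitContext

    ctx = LazyInitContext()
    with ctx:
        org_model = nn.Linear(8, 8)               # parameters are LazyTensors
        model_copy = copy.deepcopy(org_model)     # deepcopy of a lazy model stays lazy

    ctx.materialize(org_model)                    # org_model gets real weights
    # model_copy stays lazy; it is materialized during sharding (see the sharder change above)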

View File

@@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
                     atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_fused_normalization', [False, True])
+@parameterize('enable_tensor_parallelism', [False, True])
+@parameterize('use_lazy_init', [False, True])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

View File

@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import pytest
 import torch
 import torch.distributed as dist
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


-def check_shardformer_with_ddp(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize('lazy_init', [True, False])
+def check_shardformer_with_ddp(lazy_init: bool):

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
     shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
     shardformer = ShardFormer(shard_config=shard_config)

+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
-        model = model_fn().cuda()
+        with ctx:
+            model = model_fn().cuda()
         sharded_model, _ = shardformer.optimize(model)

         # add ddp
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
     torch.cuda.empty_cache()


+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_shardformer_with_ddp()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_gpt2():
-    spawn(check_shardformer_with_ddp, 4)
+    spawn(run_dist, 4)


 if __name__ == "__main__":
     test_gpt2()