[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
Author: Hongxin Liu
Date: 2023-07-10 10:48:53 +08:00
Parent: f3bcc292c8
Commit: 890774b2fb
25 changed files with 263 additions and 157 deletions
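
For context, the lazy-init flow this patch series hooks into looks roughly like the sketch below. It is a minimal illustration, not code from the diff: only the LazyInitContext import and its static materialize() call appear in the hunks that follow; the toy model and the comments about placement are assumptions.

# Minimal sketch of lazy init with ColossalAI's LazyInitContext, assuming
# only the two calls visible in the diff below.
import torch.nn as nn
from colossalai.lazy import LazyInitContext

with LazyInitContext():
    # Inside the context, parameters are recorded but not actually allocated
    # or initialized, so a large model can be built cheaply before a shard
    # policy decides where each piece should live.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.LayerNorm(1024))

# When a layer is finally replaced or used, its lazy parameters are turned
# into real tensors first.
LazyInitContext.materialize(model[1])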


@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn
+from colossalai.lazy import LazyInitContext
 __all__ = ['FusedLayerNorm', 'FusedRMSNorm']
 FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ class FusedLayerNorm():
             raise ImportError(
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+        LazyInitContext.materialize(module)
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
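
The hunk above adds a single call, but its position matters: the module must be materialized before its attributes and weights are read. A simplified, hypothetical stand-alone version of that conversion step (using nn.LayerNorm in place of the apex fused kernel) could look like this:

# Hypothetical helper mirroring the pattern above; the real code lives in
# FusedLayerNorm and builds the apex fused kernel instead of nn.LayerNorm.
import torch
import torch.nn as nn
from colossalai.lazy import LazyInitContext

def convert_layernorm(module: nn.LayerNorm) -> nn.LayerNorm:
    # Turn lazy (meta) parameters into real tensors first; otherwise copying
    # module.weight / module.bias below would fail.
    LazyInitContext.materialize(module)
    normalized_shape = module.normalized_shape
    eps = module.eps
    fused = nn.LayerNorm(normalized_shape, eps=eps)
    with torch.no_grad():
        fused.weight.copy_(module.weight)
        fused.bias.copy_(module.bias)
    return fused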
@@ -84,6 +87,7 @@ class FusedRMSNorm():
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
             )
+        LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]
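
The RMSNorm path adds the same materialize-first step, plus a duck-typed check for Hugging Face's LlamaRMSNorm, whose learnable weight is a 1-D tensor. A hedged sketch of that shape handling follows; the non-Llama fallback branch is an assumption and is not shown in the hunk.

# Sketch of the shape handling in the FusedRMSNorm hunk above. The class-name
# check avoids a hard import of transformers; the fallback branch is assumed.
from colossalai.lazy import LazyInitContext

def rmsnorm_normalized_shape(module):
    # Materialize before inspecting the weight of a lazily built module.
    LazyInitContext.materialize(module)
    if module.__class__.__name__ == "LlamaRMSNorm":
        # LlamaRMSNorm keeps a single 1-D scale vector, so the normalized
        # size is just its length.
        return module.weight.shape[0]
    # Other norm implementations are assumed to expose normalized_shape.
    return module.normalized_shape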