Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
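As a rough sketch of how these pieces fit together, a model can be constructed lazily so that parameters stay unallocated until a shardformer layer is about to convert a module, at which point the module is materialized. This assumes the usual context-manager usage of LazyInitContext; the model, layer sizes, and loop here are purely illustrative and not the library's actual conversion path.

import torch.nn as nn

from colossalai.lazy import LazyInitContext

# Build the model lazily: parameters are placeholders, which keeps host
# memory low until it is decided how each layer will be sharded/converted.
ctx = LazyInitContext()
with ctx:
    model = nn.Sequential(nn.Linear(1024, 1024), nn.LayerNorm(1024))

# Right before a module is converted (e.g. to a fused or parallel variant),
# materialize it so its weights and attributes become real tensors.
for module in model.modules():
    LazyInitContext.materialize(module)

The diff below inserts exactly this materialize step into the norm-layer wrappers.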
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn
 
+from colossalai.lazy import LazyInitContext
+
 __all__ = ['FusedLayerNorm', 'FusedRMSNorm']
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ class FusedLayerNorm():
             raise ImportError(
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
 
+        LazyInitContext.materialize(module)
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
@@ -84,6 +87,7 @@ class FusedRMSNorm():
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
             )
 
+        LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]
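Hugging Face's LlamaRMSNorm does not expose a normalized_shape attribute (it only holds a weight vector and an epsilon), so the conversion infers the shape from the weight tensor, which is why materialization has to happen first. A small sketch of that attribute extraction, with a hypothetical helper name; the variance_epsilon fallback is an assumption based on the transformers implementation, not something shown in this diff:

import torch

from colossalai.lazy import LazyInitContext

def extract_rmsnorm_args(module: torch.nn.Module):
    # Materialize first so module.weight is a real tensor, then branch on
    # the class name to recover the normalized shape and epsilon.
    LazyInitContext.materialize(module)
    if module.__class__.__name__ == "LlamaRMSNorm":
        normalized_shape = module.weight.shape[0]
        # Assumed attribute name from the HF implementation; defaulted if absent.
        eps = getattr(module, "variance_epsilon", 1e-6)
    else:
        normalized_shape = module.normalized_shape
        eps = module.eps
    return normalized_shape, eps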