Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
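For context, "inplace sharding" means the sharded replacement layer reuses the source module's parameter objects instead of allocating copies, so the converted module and the original share storage. Below is a minimal plain-PyTorch sketch of that idea, covering only the parameter reuse and not any actual tensor splitting or process groups; SharedParamLayerNorm is a hypothetical name used for illustration, not a ColossalAI API.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedParamLayerNorm(nn.Module):
    """Hypothetical wrapper that adopts an existing LayerNorm's parameters."""

    def __init__(self, src: nn.LayerNorm):
        super().__init__()
        self.normalized_shape = src.normalized_shape
        self.eps = src.eps
        # Reuse the exact Parameter objects (no copy). This object identity is
        # the "inplace" property the updated test asserts with `is`.
        self.weight = src.weight
        self.bias = src.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


# Usage: converting a native LayerNorm keeps the same weight/bias objects,
# and both modules produce identical outputs on the same input.
norm = nn.LayerNorm(128, eps=1e-5)
norm1d = SharedParamLayerNorm(norm)
assert norm1d.weight is norm.weight and norm1d.bias is norm.bias
x = torch.randn(4, 128)
assert torch.allclose(norm1d(x), norm(x))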
@@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 def check_layernorm(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
+    norm = nn.LayerNorm(128, 0.00001).cuda()
     with ctx:
-        norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
+        norm_copy = nn.LayerNorm(128, 0.00001).cuda()
+    norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])
+    assert norm_copy.weight is norm1d.weight
+    assert norm_copy.bias is norm1d.bias
 
     # ensure state dict is reversibly loadable
     norm.load_state_dict(norm1d.state_dict())
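The substance of the test change is the two new `is` assertions: after FusedLayerNorm.from_native_module(norm_copy, process_group=None), the returned module must hold the very same weight and bias Parameter objects as norm_copy rather than equal-valued copies, while the separately constructed norm stays untouched as a reference for the state-dict round trip. The following self-contained snippet illustrates the object-identity property being tested; it is plain PyTorch and independent of ColossalAI.

import torch
import torch.nn as nn

src = nn.LayerNorm(128, eps=1e-5)

# An equal-valued copy would NOT pass the test: `is` compares object identity.
copied = nn.Parameter(src.weight.detach().clone())
assert torch.equal(copied, src.weight)
assert copied is not src.weight

# Reusing the Parameter object does pass, and an in-place update is visible
# through both references, which is what inplace sharding relies on.
shared = src.weight
assert shared is src.weight
with torch.no_grad():
    shared.add_(1.0)
assert torch.equal(src.weight, shared)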