[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version

Author: Hongxin Liu
Date: 2023-07-20 10:39:06 +08:00
Parent: 2a2eacfaf1
Commit: d921ce8391
26 changed files with 371 additions and 340 deletions

@@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 def check_layernorm(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

+    norm = nn.LayerNorm(128, 0.00001).cuda()
     with ctx:
-        norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
+        norm_copy = nn.LayerNorm(128, 0.00001).cuda()
+    norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None)

     assert norm1d.weight.shape == torch.Size([128])
+    assert norm_copy.weight is norm1d.weight
+    assert norm_copy.bias is norm1d.bias

     # ensure state dict is reversibly loadable
     norm.load_state_dict(norm1d.state_dict())
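
The two new identity assertions in the hunk above check that the converted layer reuses the source module's parameter objects instead of copying them. Below is a minimal CPU-only sketch of that property in plain PyTorch. `WrappedLayerNorm` and its `from_native_module` are hypothetical stand-ins written for illustration; they are not ColossalAI's `FusedLayerNorm` and make no claim about how it is implemented.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WrappedLayerNorm(nn.Module):
    """Hypothetical wrapper for illustration; not ColossalAI's FusedLayerNorm."""

    def __init__(self, weight: nn.Parameter, bias: nn.Parameter, eps: float):
        super().__init__()
        # Assigning an existing nn.Parameter registers that same object,
        # so the wrapper shares storage with the source module.
        self.weight = weight
        self.bias = bias
        self.eps = eps

    @staticmethod
    def from_native_module(module: nn.LayerNorm) -> "WrappedLayerNorm":
        # Reuse the parameters in place rather than cloning them.
        return WrappedLayerNorm(module.weight, module.bias, module.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)


norm = nn.LayerNorm(128, 1e-5)       # separate reference module
norm_copy = nn.LayerNorm(128, 1e-5)  # module handed to the wrapper
wrapped = WrappedLayerNorm.from_native_module(norm_copy)

# Same checks as the test: identity, not just equality.
assert wrapped.weight is norm_copy.weight
assert wrapped.bias is norm_copy.bias

# The wrapper's state dict uses the same keys, so it loads back into a native module.
norm.load_state_dict(wrapped.state_dict())
```

Using `is` rather than `torch.equal` is what distinguishes in-place reuse from a value-preserving copy: a cloned parameter would pass an equality check but fail the identity check.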