[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
This commit is contained in:
Hongxin Liu
2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions

View File

@@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_embedding_1d(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
embedding = nn.Embedding(32, 128).cuda()
with ctx:
embedding = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
embedding_copy = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None)
assert embedding_1d.weight.shape == torch.Size([32, 64])
assert embedding_1d.weight is embedding_copy.weight
# ensure state dict is reversibly loadable
embedding.load_state_dict(embedding_1d.state_dict())