Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
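For context, "inplace sharding" here means that from_native_module shards the parameters of the module it is handed instead of allocating a fresh, fully materialized copy, so the sharded wrapper and the source module end up holding the same (now sharded) tensors. Below is a minimal PyTorch-only sketch of that idea; it is not ColossalAI's implementation, and shard_embedding_inplace_ is a hypothetical helper simulating a single rank.

import torch
import torch.nn as nn


def shard_embedding_inplace_(embedding: nn.Embedding, rank: int, world_size: int) -> nn.Embedding:
    # Keep only this rank's column slice of the weight and drop the rest,
    # so the module itself now holds the shard instead of a full copy.
    shard = torch.chunk(embedding.weight.data, world_size, dim=1)[rank]
    embedding.weight = nn.Parameter(shard.contiguous())
    embedding.embedding_dim = embedding.weight.shape[1]
    return embedding


full = nn.Embedding(32, 128)
sharded = shard_embedding_inplace_(full, rank=0, world_size=2)
assert sharded is full                            # same module object, sharded in place
assert full.weight.shape == torch.Size([32, 64])  # 128 // 2 columns kept on this rank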
@@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 def check_embedding_1d(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
+    embedding = nn.Embedding(32, 128).cuda()
     with ctx:
-        embedding = nn.Embedding(32, 128).cuda()
-    embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
+        embedding_copy = nn.Embedding(32, 128).cuda()
+    embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None)
 
     assert embedding_1d.weight.shape == torch.Size([32, 64])
+    assert embedding_1d.weight is embedding_copy.weight
 
     # ensure state dict is reversibly loadable
     embedding.load_state_dict(embedding_1d.state_dict())
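The updated test checks two properties of in-place sharding: the sharded layer's weight is literally the source module's (now sharded) parameter, and the sharded layer's state dict can still be loaded back into a plain nn.Embedding, which implies it exposes full-sized tensors. The single-process sketch below imitates that second, "reversibly loadable" property; ShardedEmbeddingSketch is hypothetical and only simulates the cross-rank gather that the real Embedding1D would perform.

import torch
import torch.nn as nn


class ShardedEmbeddingSketch(nn.Module):
    def __init__(self, full: nn.Embedding, world_size: int = 2):
        super().__init__()
        # Keep only the local column shard (rank 0 here), as in 1D tensor parallelism.
        self.shards = list(torch.chunk(full.weight.data, world_size, dim=1))
        self.weight = nn.Parameter(self.shards[0].contiguous())

    def state_dict(self, *args, **kwargs):
        # Reassemble the full weight so the state dict matches nn.Embedding's layout.
        return {"weight": torch.cat(self.shards, dim=1)}


embedding = nn.Embedding(32, 128)
sharded = ShardedEmbeddingSketch(embedding)
assert sharded.weight.shape == torch.Size([32, 64])
embedding.load_state_dict(sharded.state_dict())  # round-trips without key or shape errors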