mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading * polish code
This commit is contained in:
@@ -14,6 +14,10 @@ def check_embedding_1d():
|
||||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(embedding_1d.state_dict())
|
||||
embedding_1d.load_state_dict(embedding.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
|
||||
out = embedding(x)
|
||||
|
Reference in New Issue
Block a user