[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -14,6 +14,10 @@ def check_embedding_1d():
assert embedding_1d.weight.shape == torch.Size([32, 64])
# ensure state dict is reversibly loadable
embedding.load_state_dict(embedding_1d.state_dict())
embedding_1d.load_state_dict(embedding.state_dict())
# check computation correctness
x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
out = embedding(x)