[test] polish zero related unitest (#351)

This commit is contained in:
Jiarui Fang
2022-03-10 09:57:26 +08:00
committed by Frank Lee
parent 534e0bb118
commit cb34cd384d
5 changed files with 75 additions and 123 deletions

View File

@@ -0,0 +1,19 @@
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
"""
copy param of the ShardedModelV2 to other_model.
Note the other_model has to be the same as self.
"""
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
assert hasattr(zero_param, 'col_attr')
shard_flag = zero_param.col_attr.data.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.col_attr.data])
param.data = copy.deepcopy(zero_param.col_attr.data.payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.col_attr.data])