[zero] fix init bugs in zero context (#686)

* adapt model weight initialization for methods in PyTorch nn.init
This commit is contained in:
HELSON
2022-04-07 17:38:45 +08:00
committed by GitHub
parent 0ed7042f42
commit d7ecaf362b
8 changed files with 117 additions and 86 deletions

View File

@@ -60,8 +60,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
with ZeroInitContext(
target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=False):
shard_param=True):
zero_model = MoeModel()
zero_model = ShardedModelV2(