[hotfix] shared model returns cpu state_dict (#1328)

This commit is contained in:
ver217
2022-07-15 22:11:37 +08:00
committed by GitHub
parent b2475d8c5c
commit 7a05367101
2 changed files with 3 additions and 2 deletions

View File

@@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key])
assert torch.equal(val, zero_state_dict[key].to(val.device))
def run_dist(rank, world_size, port):