mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 07:31:19 +00:00
[hotfix] shared model returns cpu state_dict (#1328)
This commit is contained in:
@@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
|
||||
|
||||
zero_state_dict = zero_model.state_dict()
|
||||
for key, val in model.state_dict().items():
|
||||
assert torch.equal(val, zero_state_dict[key])
|
||||
assert torch.equal(val, zero_state_dict[key].to(val.device))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
Reference in New Issue
Block a user