Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[hotfix] fix bugs for unsharded parameters when restore data (#664)
@@ -264,8 +264,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                 p.colo_attr.saved_grad.set_null()
                 if recover_data and reuse_fp16_shard:
+                    # We should write like this to trigger ForceFP32Paramter's half method
+                    p.data = self.master_params[p].payload
                     p.colo_attr.sharded_data_tensor.reset_payload(
-                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))
+                        colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+
+                if not p.colo_attr.param_is_sharded:
+                    # FIXME(hhc): add hook for unsharded parameters
+                    p.data = p.colo_attr.sharded_data_tensor.payload
 
     def sync_grad(self):
         pass
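Context for the hunk above: when reuse_fp16_shard is true, the fp16 shard buffer was reused to hold gradients, so after the optimizer step the fp16 payload must be rebuilt from the fp32 master copy. The new code routes that copy through p.data and p.half() so that ForceFP32Parameter's overridden half() is triggered, and it additionally re-points p.data for parameters that were never sharded. Below is a minimal sketch of the underlying master-copy recovery pattern in plain PyTorch; it only illustrates the idea under simplified assumptions and is not ColossalAI's API, and the variable names are made up.

import torch

# Sketch: keep an fp32 master copy of an fp16 parameter, step the optimizer
# on the master copy, then restore the fp16 payload from it -- the same
# recovery idea the hunk above implements after each step.
fp16_param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
master_param = fp16_param.detach().clone().float().requires_grad_(True)
optimizer = torch.optim.SGD([master_param], lr=0.1)

# pretend a backward pass produced an fp16 gradient, upcast it for the step
master_param.grad = torch.randn_like(fp16_param).float()
optimizer.step()
optimizer.zero_grad()

# recover the fp16 parameter data from the updated fp32 master copy
fp16_param.data = master_param.detach().half()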
@@ -281,7 +287,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
                 if not is_param_sharded and not self.keep_unshard:
                     # In this branch, there's no need to shard param
                     # So we gather here
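The one-line change above only reorders the device transfer and the dtype cast when building the fp32 master copy: the fp16 payload is now moved to self.device first and cast to fp32 there, rather than being cast to fp32 on its current device and moved afterwards. Presumably this transfers half as many bytes and avoids a temporary fp32 tensor on the source (typically CUDA) device when master params are kept on the CPU. A small sketch of the pattern follows; make_master_copy is a hypothetical helper for illustration, not a ColossalAI function.

import torch

def make_master_copy(fp16_payload: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Move the half-precision payload to the target device first, then upcast
    # to fp32 there (the ordering the new line above uses); no fp32 temporary
    # is allocated on the source device.
    return fp16_payload.to(device).float()

# usage: build a CPU-resident fp32 master copy from a (possibly CUDA) fp16 shard
src = "cuda" if torch.cuda.is_available() else "cpu"
payload = torch.randn(1024, dtype=torch.float16, device=src)
master = make_master_copy(payload, torch.device("cpu"))
assert master.dtype == torch.float32 and master.device.type == "cpu"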