mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
[zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model * polish comments * polish comments
This commit is contained in:
@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
|
||||
from ._base_ophook import BaseOpHook
|
||||
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
|
||||
|
||||
from colossalai.utils.memory_utils.utils import \
|
||||
colo_model_data_tensor_move_inline
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters(recurse=False):
|
||||
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.col_attr.sharded_data_tensor.payload
|
||||
# Store local accumulated grad shard
|
||||
if param.grad is not None:
|
||||
if param.col_attr.bwd_count == 0:
|
||||
# We haven't stored local accumulated grad yet
|
||||
assert param.col_attr.fp32_grad.is_null()
|
||||
|
||||
# Allocate grad fp32 memory space here
|
||||
param.col_attr.fp32_grad.reset_payload(param.grad.data)
|
||||
# TODO(jiaruifang) we should set grad fp16 state to HOLD here.
|
||||
param.grad = None
|
||||
else:
|
||||
# We have stored local accumulated grad
|
||||
# The grad here must be locally computed full grad in this backward pass
|
||||
assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
|
||||
param.col_attr.bwd_count += 1
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
|
Reference in New Issue
Block a user