diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 4f0297d34..70e14548b 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
         self._iter_cnter = 0

         # Register hooks
-        register_ophooks_recursively(self.module,
-                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
+        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
     def backward(self, loss):
         loss.backward()
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
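
The diff above changes `ShardedModelV2` to keep a reference to its op-hook list (`self._ophook_list`) instead of passing a throwaway list to `register_ophooks_recursively`, so that every hook's `post_iter()` can be invoked once both backward paths finish. Below is a minimal sketch of that pattern, assuming a hypothetical `SimpleHook` and `ToyShardedModel` in place of ColossalAI's `ZeroHook` and real sharded model; only the hook-retention and end-of-iteration callback idea is taken from the diff.

```python
import torch
import torch.nn as nn


class SimpleHook:
    """Hypothetical stand-in for ZeroHook: tracks per-iteration state."""

    def __init__(self):
        self.iterations = 0

    def post_iter(self):
        # Called once per training iteration, after backward finishes,
        # so the hook can reset or flush any per-iteration state.
        self.iterations += 1


class ToyShardedModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.module = nn.Linear(4, 1)
        # Keep the hook list on self (the point of the diff) rather than
        # handing an anonymous list to the registration helper.
        self._ophook_list = [SimpleHook()]

    def forward(self, x):
        return self.module(x)

    def backward(self, loss):
        loss.backward()
        # Notify every registered hook that the iteration is over.
        for ophook in self._ophook_list:
            ophook.post_iter()


model = ToyShardedModel()
loss = model(torch.randn(2, 4)).sum()
model.backward(loss)
assert model._ophook_list[0].iterations == 1
```

Retaining the list is what makes the `post_iter()` fan-out possible in both `backward` and `backward_by_grad`; with the previous code the hooks were unreachable after registration.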