Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] sharded model manages ophooks individually (#492)
This commit is contained in:
parent c9023d4078
commit c4c02424f3
@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
         self._iter_cnter = 0
         # Register hooks
-        register_ophooks_recursively(self.module,
-                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
+        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
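The point of this hunk is ownership: the hook list is bound to self._ophook_list before registration, so the model can still reach its hooks after construction instead of handing register_ophooks_recursively a throwaway inline list. A minimal sketch of that pattern follows; MiniOpHook, register_recursively, and Owner are illustrative stand-ins, not ColossalAI's real ZeroHook or register_ophooks_recursively.

import torch.nn as nn


class MiniOpHook:
    """Stand-in for an op hook; only the parts the pattern needs."""

    def pre_fwd_exec(self, module: nn.Module, inputs) -> None:
        # A real ZeroHook would gather/shard parameters or sample memory here.
        pass

    def post_iter(self) -> None:
        # Per-iteration cleanup; exercised by the second hunk below.
        pass


def register_recursively(module: nn.Module, hooks: list) -> None:
    # Walk the module tree and install every hook on every submodule.
    for sub in module.modules():
        for hook in hooks:
            sub.register_forward_pre_hook(hook.pre_fwd_exec)


class Owner(nn.Module):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module
        # Keeping the list on self is the whole change: an inline,
        # unnamed list would be unreachable once __init__ returns.
        self._ophook_list = [MiniOpHook()]
        register_recursively(self.module, self._ophook_list)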
@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
     def backward(self, loss):
         loss.backward()
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
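This second hunk uses the stored list: both backward entry points now notify every hook that the iteration has finished via post_iter(). A self-contained sketch of the resulting lifecycle, with an illustrative CountingHook in place of ZeroHook and the real gradient post-processing elided:

import torch
import torch.nn as nn


class CountingHook:
    """Illustrative hook that tracks how many iterations have completed."""

    def __init__(self) -> None:
        self.iters_done = 0

    def post_iter(self) -> None:
        # Called once per backward pass; reset or flush per-iteration state.
        self.iters_done += 1


class MiniShardedModel(nn.Module):
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module
        self._ophook_list = [CountingHook()]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)

    def backward(self, loss: torch.Tensor) -> None:
        loss.backward()
        # ... gradient post-processing (reduce/shard) would happen here ...
        for ophook in self._ophook_list:
            ophook.post_iter()


model = MiniShardedModel(nn.Linear(4, 1))
loss = model(torch.randn(2, 4)).sum()
model.backward(loss)
assert model._ophook_list[0].iters_done == 1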