Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 09:59:38 +00:00

[zero] Update sharded model v2 using sharded param v2 (#323)
@@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
     if type(outputs) is tuple:
         touched_outputs = []
         for output in outputs:
-            touched_output = _apply_to_tensors_only(module, functional,
-                                                    backward_function, output)
+            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
             touched_outputs.append(touched_output)
         return tuple(touched_outputs)
     elif type(outputs) is torch.Tensor:
@@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
 
+
 class PreBackwardFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, module, pre_backward_function, outputs):
         ctx.module = module
@@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
 
+
 class PostBackwardFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, module, pre_backward_function, output):
         ctx.module = module
@@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args
 
 
-def register_ophooks_recursively(module: torch.nn.Module,
-                                 ophook_list: List[BaseOpHook] = None,
-                                 name: str = ""):
+def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
     r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
     has_children = False
@@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
 
     # Early return on modules with no parameters or buffers that
     # are not in their children.
-    if (len(list(module.named_parameters(recurse=False))) == 0
-            and len(list(module.named_buffers(recurse=False))) == 0):
+    if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
         return
 
     # return if the module has no children.
@@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
             hook.post_fwd_exec(submodule, *args)
 
     def _pre_backward_module_hook(submodule, inputs, output):
+
         def _run_before_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
                 hook.pre_bwd_exec(submodule, inputs, output)
 
-        return _apply_to_tensors_only(submodule, PreBackwardFunction,
-                                      _run_before_backward_function, output)
+        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
 
     def _post_backward_module_hook(submodule, inputs):
+
        def _run_after_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
                 hook.post_bwd_exec(submodule, inputs)
 
-        return _apply_to_tensors_only(submodule, PostBackwardFunction,
-                                      _run_after_backward_function, inputs)
+        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
 
     module.register_forward_pre_hook(_pre_forward_module_hook)
     module.register_forward_hook(_post_forward_module_hook)
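
For orientation, the reformatted `register_ophooks_recursively` above is driven by passing a root module and a list of op hooks; every submodule with its own parameters or buffers then receives the pre/post forward and backward hooks defined in this file. Below is a minimal usage sketch: the `LoggingOpHook` is a toy hook invented for illustration, and the import path of `BaseOpHook` and `register_ophooks_recursively` is assumed from the package layout in this commit, so treat it as a sketch rather than the library's documented API.

import torch

# Assumed import path: both names live under colossalai/engine/ophooks in this commit.
from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively


class LoggingOpHook(BaseOpHook):
    """Toy hook for illustration: reports which submodule enters FWD/BWD."""

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'pre-forward:   {module.__class__.__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'post-forward:  {module.__class__.__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f'pre-backward:  {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        print(f'post-backward: {module.__class__.__name__}')

    def pre_iter(self):
        pass

    def post_iter(self):
        pass


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
register_ophooks_recursively(model, [LoggingOpHook()])

# Forward and backward now trigger the registered hooks on each parameterized submodule.
model(torch.randn(2, 8)).sum().backward()
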
colossalai/engine/ophooks/zero_hook.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import torch
+from colossalai.registry import OPHOOKS
+from colossalai.zero.shard_utils import BaseShardStrategy
+
+from ._base_ophook import BaseOpHook
+
+
+@OPHOOKS.register_module
+class ZeroHook(BaseOpHook):
+    """
+    A hook to process sharded param for ZeRO method.
+    """
+
+    def __init__(self, shard_strategy: BaseShardStrategy):
+        super().__init__()
+        self.shard_strategy = shard_strategy
+
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters():
+            assert hasattr(param, 'col_attr')
+            self.shard_strategy.gather([param.col_attr.data])
+            param.data = param.col_attr.data.payload
+
+    def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters():
+            assert hasattr(param, 'col_attr')
+            self.shard_strategy.shard([param.col_attr.data])
+            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+
+    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters():
+            assert hasattr(param, 'col_attr')
+            self.shard_strategy.gather([param.col_attr.data])
+            param.data = param.col_attr.data.payload
+            # Store local accumulated grad shard
+            if param.grad is not None:
+                if param.col_attr.bwd_count == 0:
+                    # We haven't stored local accumulated grad yet
+                    assert param.col_attr.grad is None
+                    param.col_attr.grad = param.grad.data
+                    param.grad = None
+                else:
+                    # We have stored local accumulated grad
+                    # The grad here must be locally computed full grad in this backward pass
+                    assert param.grad.shape == param.col_attr.data.origin_shape
+            param.col_attr.bwd_count += 1
+
+    def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters():
+            assert hasattr(param, 'col_attr')
+            self.shard_strategy.shard([param.col_attr.data])
+            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+
+    def pre_iter(self):
+        pass
+
+    def post_iter(self):
+        pass
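
The gather/shard calls above rely on every parameter carrying a `col_attr` whose fields are only implied by this diff. The stand-in classes below spell out that implicit contract purely as a reading aid; the `_Fake*` names are hypothetical and the no-op `gather`/`shard` bodies are placeholders, not ColossalAI's real sharded-param or shard-strategy implementations.

import torch


class _FakeShardedTensor:
    """Stand-in for param.col_attr.data: ZeroHook reads .payload, .dtype and .origin_shape."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor              # full tensor when gathered, local shard otherwise
        self.dtype = tensor.dtype
        self.origin_shape = tensor.shape   # unsharded shape, checked in pre_bwd_exec


class _FakeColAttr:
    """Stand-in for the col_attr object that ZeroHook asserts on every parameter."""

    def __init__(self, tensor: torch.Tensor):
        self.data = _FakeShardedTensor(tensor)
        self.grad = None                   # local accumulated grad, filled on the first backward
        self.bwd_count = 0                 # number of backward passes seen for this parameter


class _FakeShardStrategy:
    """Stand-in for a BaseShardStrategy: gather/shard both take a list of sharded tensors."""

    def gather(self, tensor_list):
        pass                               # a real strategy all-gathers the shards here

    def shard(self, tensor_list):
        pass                               # a real strategy re-partitions the full tensors here


# Attaching a col_attr like this is what makes a parameter consumable by ZeroHook.
param = torch.nn.Parameter(torch.randn(4, 4))
param.col_attr = _FakeColAttr(param.data)
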