[zero] Update sharded model v2 using sharded param v2 (#323)

Author: ver217
Date: 2022-03-08 18:18:06 +08:00
Committed by: Frank Lee
Parent: 799d105bb4
Commit: 1388671699
16 changed files with 403 additions and 202 deletions

View File

@@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
-            touched_output = _apply_to_tensors_only(module, functional,
-                                                    backward_function, output)
+            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
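
The helper above walks a module's (possibly tuple-valued) outputs and wraps every tensor with the autograd Function passed in, so that a callback can fire when gradients flow back through that tensor; the change here only reflows the recursive call onto one line. A self-contained sketch of the same wrapping trick (the class and callback below are made up for illustration, not code from this commit):

import torch

class _CallbackOnBackward(torch.autograd.Function):
    # Identity in the forward pass; runs a stored callback when the gradient arrives.
    @staticmethod
    def forward(ctx, callback, tensor):
        ctx.callback = callback
        return tensor.detach()   # detach so this Function becomes the output's grad_fn

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.callback()                  # e.g. gather parameter shards before the real backward
        return (None,) + grad_outputs   # one None for the non-tensor `callback` argument

x = torch.randn(3, requires_grad=True)
y = _CallbackOnBackward.apply(lambda: print('about to backprop through this point'), x)
y.sum().backward()   # prints the message, then x.grad is populated as usual
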
@@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
class PreBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
@@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
class PostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
@@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
        return (None, None) + args


-def register_ophooks_recursively(module: torch.nn.Module,
-                                 ophook_list: List[BaseOpHook] = None,
-                                 name: str = ""):
+def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)
    has_children = False
@@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
    # Early return on modules with no parameters or buffers that
    # are not in their children.
-    if (len(list(module.named_parameters(recurse=False))) == 0
-            and len(list(module.named_buffers(recurse=False))) == 0):
+    if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
        return

    # return if the module has no children.
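
The early-return check above counts only the parameters and buffers a module owns directly (recurse=False), so container modules whose parameters all live in children are skipped and the hooks end up on the modules that actually own parameters. A quick illustration of that distinction (illustrative only, not from this commit):

import torch

seq = torch.nn.Sequential(torch.nn.Linear(4, 4))
print(len(list(seq.named_parameters(recurse=False))))      # 0 -> Sequential owns nothing directly, early return
print(len(list(seq.named_parameters())))                   # 2 -> weight and bias live in the child Linear
print(len(list(seq[0].named_parameters(recurse=False))))   # 2 -> the Linear owns them, so it gets the hooks
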
@@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

-        return _apply_to_tensors_only(submodule, PreBackwardFunction,
-                                      _run_before_backward_function, output)
+        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

-        return _apply_to_tensors_only(submodule, PostBackwardFunction,
-                                      _run_after_backward_function, inputs)
+        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)
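
Apart from the reflowed call sites and signature, the registration logic above is unchanged: each submodule that directly owns parameters or buffers gets a forward pre-hook and a forward hook, and its outputs/inputs are wrapped through PreBackwardFunction/PostBackwardFunction so the hook's backward-side callbacks fire as gradients flow. A rough usage sketch with a do-nothing logging hook (the LoggingOpHook class and the import path are assumptions for illustration, not part of this commit):

import torch
from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively  # import path assumed

class LoggingOpHook(BaseOpHook):
    # Hypothetical hook that just reports when each submodule is entered in FWD/BWD.
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'pre-forward   {type(module).__name__}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f'post-forward  {type(module).__name__}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f'pre-backward  {type(module).__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        print(f'post-backward {type(module).__name__}')

    def pre_iter(self):
        pass

    def post_iter(self):
        pass

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
register_ophooks_recursively(model, [LoggingOpHook()])
out = model(torch.randn(4, 8))
out.sum().backward()   # backward-side messages appear as gradients reach each Linear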

View File

@@ -0,0 +1,58 @@
import torch
from colossalai.registry import OPHOOKS
from colossalai.zero.shard_utils import BaseShardStrategy

from ._base_ophook import BaseOpHook


@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.
    """

    def __init__(self, shard_strategy: BaseShardStrategy):
        super().__init__()
        self.shard_strategy = shard_strategy

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            param.data = param.col_attr.data.payload

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.shard([param.col_attr.data])
            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            param.data = param.col_attr.data.payload
            # Store local accumulated grad shard
            if param.grad is not None:
                if param.col_attr.bwd_count == 0:
                    # We haven't stored local accumulated grad yet
                    assert param.col_attr.grad is None
                    param.col_attr.grad = param.grad.data
                    param.grad = None
                else:
                    # We have stored local accumulated grad
                    # The grad here must be locally computed full grad in this backward pass
                    assert param.grad.shape == param.col_attr.data.origin_shape
            param.col_attr.bwd_count += 1

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.shard([param.col_attr.data])
            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
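
Taken together with the previous file, the flow this new hook implements is: before a submodule runs (forward or backward), ZeroHook gathers that submodule's parameter shards and points param.data at the full payload; afterwards it re-shards and leaves param.data as an empty placeholder, while pre_bwd_exec additionally stashes a previously accumulated gradient in param.col_attr.grad and counts backward passes via bwd_count. A minimal wiring sketch (the TensorShardStrategy name and the import paths are assumptions; a real run needs a sharded-model-v2 style model whose parameters carry col_attr):

import torch
from colossalai.zero.shard_utils import TensorShardStrategy           # concrete strategy class assumed
from colossalai.engine.ophooks import register_ophooks_recursively    # import path assumed
from colossalai.engine.ophooks.zero_hook import ZeroHook              # module path assumed

model = torch.nn.Linear(1024, 1024)   # stand-in; parameters would normally be sharded and carry `col_attr`
hook = ZeroHook(TensorShardStrategy())
register_ophooks_recursively(model, [hook])
# forward:  pre_fwd_exec gathers shards -> compute -> post_fwd_exec re-shards and frees param.data
# backward: pre_bwd_exec gathers shards -> grad compute -> post_bwd_exec re-shards again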