from typing import List

import torch

from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_grad_ophook import ShardGradHook
from ._shard_param_ophook import ShardParamHook
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]
|
|
|
|
|
|
# Apply a torch.autograd.Function that calls `backward_function` to every tensor in `outputs`.
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        # Recurse into tuple outputs so that every tensor element gets wrapped.
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        # Non-tensor outputs (e.g. None or plain Python values) are returned unchanged.
        return outputs
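# Illustrative sketch (comment only; `mod`, `fn` and `logits` are hypothetical names):
# for a module `mod` whose forward returns `(logits, None)`, a call like
#     _apply_to_tensors_only(mod, PreBackwardFunction, fn, (logits, None))
# wraps `logits` in PreBackwardFunction and passes the `None` through unchanged.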

class PreBackwardFunction(torch.autograd.Function):
    """Identity in the forward pass; its backward runs ``pre_backward_function`` on the
    module before gradients flow back into the op that produced this output."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        # No gradients for the `module` and `pre_backward_function` arguments of forward.
        return (None, None) + args

class PostBackwardFunction(torch.autograd.Function):
    """Identity in the forward pass; its backward runs ``pre_backward_function`` on the
    module after the module's own backward has produced gradients for its inputs."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            *args: the activation gradient of the next layer.
        Returns:
            The gradient of the input activation, plus ``None`` for the non-tensor
            arguments of ``forward``.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args

def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)

    # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
        return

    if ophook_list is not None:
        for hook in ophook_list:
            assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        # Wrap the forward outputs so the pre-backward hooks fire when gradients
        # first reach this submodule during the backward pass.
        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        # Wrap the forward inputs so the post-backward hooks fire once this
        # submodule's backward has finished.
        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward-side hooks are attached during the forward pass: a forward hook
    # wraps the outputs with PreBackwardFunction and a forward pre-hook wraps the
    # inputs with PostBackwardFunction.
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
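
# Illustrative usage sketch (comment only, not executed). The hook class and model
# below are hypothetical, the import path is assumed from this package's location,
# and BaseOpHook (see ._base_ophook) may require additional abstract methods such
# as `post_iter`.
#
#     import torch
#     from colossalai.engine.ophooks import BaseOpHook, register_ophooks_recursively
#
#     class PrintOpHook(BaseOpHook):
#         def pre_fwd_exec(self, module, *args):
#             print(f"before forward of {module.__class__.__name__}")
#
#         def post_fwd_exec(self, module, *args):
#             print(f"after forward of {module.__class__.__name__}")
#
#         def pre_bwd_exec(self, module, input, output):
#             print(f"before backward of {module.__class__.__name__}")
#
#         def post_bwd_exec(self, module, input):
#             print(f"after backward of {module.__class__.__name__}")
#
#         def post_iter(self):    # assumed to be required by BaseOpHook
#             pass
#
#     model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
#     register_ophooks_recursively(model, [PrintOpHook()])
#     model(torch.randn(4, 8)).sum().backward()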