Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/zero/gemini/ophooks/utils.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

import torch


class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule"""

    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    @abstractmethod
    def post_iter(self):
        pass


# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
"""
|
||||
Args:
|
||||
activation_grad of the next layer.
|
||||
Returns:
|
||||
grad of the input activation.
|
||||
"""
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook],
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
    assert isinstance(module, torch.nn.Module)
    assert isinstance(ophook_list, (list, tuple))
    assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
        return

    # Early return on modules excluded by filter_fn.
    if filter_fn is not None and filter_fn(module):
        return

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward-side hooks are attached during the forward pass: the forward hook wraps the
    # submodule's outputs with PreBackwardFunction so pre_bwd_exec fires when their gradients
    # arrive, and the forward pre-hook wraps the inputs with PostBackwardFunction so post_bwd_exec
    # fires once the submodule's backward has produced gradients for its inputs.
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
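For readers unfamiliar with these op hooks, here is a minimal usage sketch (not part of the file above): a hypothetical PrintOpHook subclass of BaseOpHook that logs each hooked submodule during forward and backward, registered with register_ophooks_recursively. The import path follows the file location shown in this diff; the hook class and the toy model are illustrative assumptions, not APIs from the repository.

import torch
import torch.nn as nn

from colossalai.legacy.zero.gemini.ophooks.utils import BaseOpHook, register_ophooks_recursively


class PrintOpHook(BaseOpHook):
    """Hypothetical hook that logs every hooked submodule as it runs forward and backward."""

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f"before forward:  {module.__class__.__name__}")

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f"after forward:   {module.__class__.__name__}")

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f"before backward: {module.__class__.__name__}")

    def post_bwd_exec(self, module: torch.nn.Module, input):
        print(f"after backward:  {module.__class__.__name__}")

    def post_iter(self):
        print("iteration finished")


# Toy model: only submodules that own parameters (the two Linear layers) receive hooks,
# since register_ophooks_recursively returns early on parameter-less modules such as ReLU.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
hook = PrintOpHook()
register_ophooks_recursively(model, [hook])

loss = model(torch.randn(2, 4)).sum()
loss.backward()
# post_iter is not wired to any PyTorch hook; the caller invokes it at the end of an iteration.
hook.post_iter()

The backward-side prints come from the PreBackwardFunction and PostBackwardFunction wrappers that the forward-pass hooks attach to each submodule's tensors, as described in the comments above.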