mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-06 06:02:16 +00:00
* impl chunk manager * impl param op hook * add reduce_chunk * add zero hook v2 * add zero dp * fix TensorInfo * impl load balancing when using zero without chunk * fix zero hook * polish chunk * fix bugs * ddp ok * zero ok * polish code * fix bugs about load balancing * polish code * polish code * add ene-to-end test * polish code * polish code * polish code * fix typo * add test_chunk * fix bugs * fix bugs * polish code
72 lines
1.7 KiB
Python
72 lines
1.7 KiB
Python
import torch
|
|
from contextlib import contextmanager
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Tuple
|
|
|
|
|
|
class ParamOpHook(ABC):
    """Abstract interface for callbacks fired around parameter operations.

    Subclasses must implement all four callbacks. Each one receives the
    list of tensors passed as ``params`` and returns nothing.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked before the forward computation."""

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked after the forward computation."""

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked before the backward computation."""

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        """Callback invoked after the backward computation."""
|
|
|
|
|
|
class _ParamOpHookWrapper:
    """Holder for the tuple of currently active ``ParamOpHook`` objects."""

    # Class-level state: the hooks in effect right now (empty by default).
    hooks: Tuple[ParamOpHook, ...] = ()
|
|
|
|
|
|
class PreFwdPostBwd(torch.autograd.Function):
    """Autograd function firing ``pre_forward`` and ``post_backward`` hooks.

    The forward pass hands ``args`` back unchanged; the backward pass hands
    the incoming gradients back unchanged, with ``None`` in the slot that
    corresponds to ``params``.
    """

    @staticmethod
    def forward(ctx, params, *args):
        """Run every active ``pre_forward`` hook, then pass ``args`` through.

        ``params`` is stored on ``ctx`` so ``backward`` can give it to the
        ``post_backward`` hooks. A single extra argument is returned bare;
        any other number of arguments is returned as a tuple.
        """
        ctx.params = params
        for op_hook in _ParamOpHookWrapper.hooks:
            op_hook.pre_forward(params)
        return args[0] if len(args) == 1 else args

    @staticmethod
    def backward(ctx, *grads):
        """Run every active ``post_backward`` hook, then pass ``grads`` through."""
        for op_hook in _ParamOpHookWrapper.hooks:
            op_hook.post_backward(ctx.params)
        # Leading ``None`` is the gradient slot for the ``params`` input.
        return (None, *grads)
|
|
|
|
|
|
class PostFwdPreBwd(torch.autograd.Function):
    """Autograd function firing ``post_forward`` and ``pre_backward`` hooks.

    The forward pass hands ``args`` back unchanged; the backward pass hands
    the incoming gradients back unchanged, with ``None`` in the slot that
    corresponds to ``params``.
    """

    @staticmethod
    def forward(ctx, params, args):
        """Run every active ``post_forward`` hook, then return ``args`` as is.

        ``params`` is stored on ``ctx`` so ``backward`` can give it to the
        ``pre_backward`` hooks.
        """
        ctx.params = params
        for op_hook in _ParamOpHookWrapper.hooks:
            op_hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        """Run every active ``pre_backward`` hook, then pass ``grads`` through."""
        for op_hook in _ParamOpHookWrapper.hooks:
            op_hook.pre_backward(ctx.params)
        # Leading ``None`` is the gradient slot for the ``params`` input.
        return (None, *grads)
|
|
|
|
|
|
@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    """Temporarily install ``hooks`` as the active param-op hooks.

    On entry, ``_ParamOpHookWrapper.hooks`` is replaced by the given
    ``hooks`` tuple. On exit, including when the body raises, the hooks
    that were active before entry are put back.

    Args:
        *hooks: The hooks to make active for the duration of the ``with``
            block. Passing none installs an empty tuple.
    """
    # Capture and install before ``try``: the ``finally`` clause must always
    # have a saved value to restore, and only the ``yield`` needs guarding.
    previous_hooks = _ParamOpHookWrapper.hooks
    _ParamOpHookWrapper.hooks = hooks
    try:
        yield
    finally:
        _ParamOpHookWrapper.hooks = previous_hooks
|