Mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] ColoTensor supports ZeRO (#1015)
* impl chunk manager
* impl param op hook
* add reduce_chunk
* add zero hook v2
* add zero dp
* fix TensorInfo
* impl load balancing when using zero without chunk
* fix zero hook
* polish chunk
* fix bugs
* ddp ok
* zero ok
* polish code
* fix bugs about load balancing
* polish code
* polish code
* add end-to-end test
* polish code
* polish code
* polish code
* fix typo
* add test_chunk
* fix bugs
* fix bugs
* polish code
71 colossalai/tensor/param_op_hook.py (new file)
@@ -0,0 +1,71 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple


class ParamOpHook(ABC):
    """Hook interface called around operators that consume parameters.

    Subclasses implement the four lifecycle callbacks; each receives the
    list of parameters involved in the operator.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


class _ParamOpHookWrapper:
    # Currently active hooks; set by the use_param_op_hooks() context manager.
    hooks: Tuple[ParamOpHook, ...] = tuple()


class PreFwdPostBwd(torch.autograd.Function):
    """Identity op: fires pre_forward in forward, post_backward in backward."""

    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_forward(ctx.params)
        if len(args) == 1:
            return args[0]
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_backward(ctx.params)
        # The leading None is the gradient slot for the non-tensor `params` arg.
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):
    """Identity op: fires post_forward in forward, pre_backward in backward."""

    @staticmethod
    def forward(ctx, params, args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_backward(ctx.params)
        return (None,) + grads


@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    # Install the given hooks for the duration of the with-block,
    # restoring the previous hook set on exit.
    try:
        old_param_op_hooks = _ParamOpHookWrapper.hooks
        _ParamOpHookWrapper.hooks = hooks
        yield
    finally:
        _ParamOpHookWrapper.hooks = old_param_op_hooks
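As a usage sketch, the hooks above could be driven like this. ParamOpHook, use_param_op_hooks, PreFwdPostBwd, and PostFwdPreBwd come from the file in this commit; the PrintHook class and the toy matmul are hypothetical illustrations, not part of the commit. Note that backward() must run inside the context manager, since the hook set is looked up at call time.

import torch
from colossalai.tensor.param_op_hook import (ParamOpHook, PreFwdPostBwd,
                                             PostFwdPreBwd, use_param_op_hooks)


class PrintHook(ParamOpHook):
    # Hypothetical hook that just logs each lifecycle event.

    def pre_forward(self, params):
        print(f'pre_forward: {len(params)} param(s)')

    def post_forward(self, params):
        print(f'post_forward: {len(params)} param(s)')

    def pre_backward(self, params):
        print(f'pre_backward: {len(params)} param(s)')

    def post_backward(self, params):
        print(f'post_backward: {len(params)} param(s)')


weight = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)

with use_param_op_hooks(PrintHook()):
    # Bracket the parameter op between the two autograd Functions:
    # pre_forward fires here; post_backward fires when gradients
    # flow back through PreFwdPostBwd.
    x = PreFwdPostBwd.apply([weight], x)
    out = x @ weight.t()
    # post_forward fires here; pre_backward fires first during backward.
    out = PostFwdPreBwd.apply([weight], out)
    # Run backward inside the with-block so the hooks are still installed.
    out.sum().backward()

Pairing the two Functions brackets a parameter op from both directions: PreFwdPostBwd fires pre_forward on entry and post_backward on the way back, while PostFwdPreBwd fires post_forward on exit and pre_backward first during backward. A ZeRO-style hook can use this to gather parameter shards before an op touches them and release them afterwards in both passes, which is presumably how the zero hook v2 added in this commit uses it.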