ColossalAI/colossalai/tensor/param_op_hook.py
ver217 9492a561c3
[tensor] ColoTensor supports ZeRO (#1015)
* impl chunk manager

* impl param op hook

* add reduce_chunk

* add zero hook v2

* add zero dp

* fix TensorInfo

* impl load balancing when using zero without chunk

* fix zero hook

* polish chunk

* fix bugs

* ddp ok

* zero ok

* polish code

* fix bugs about load balancing

* polish code

* polish code

* add end-to-end test

* polish code

* polish code

* polish code

* fix typo

* add test_chunk

* fix bugs

* fix bugs

* polish code
2022-05-31 12:00:12 +08:00


import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple


class ParamOpHook(ABC):
    """Interface for hooks that run around operations on parameters.

    ``pre_forward``/``post_forward`` are triggered around the forward pass of an op
    that uses the given parameters, and ``pre_backward``/``post_backward`` around the
    corresponding backward pass.
    """

    @abstractmethod
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def pre_backward(self, params: List[torch.Tensor]) -> None:
        pass

    @abstractmethod
    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


class _ParamOpHookWrapper:
    # Currently active hooks; swapped in and out by ``use_param_op_hooks``.
    hooks: Tuple[ParamOpHook, ...] = tuple()


class PreFwdPostBwd(torch.autograd.Function):
    """Fires ``pre_forward`` when applied and ``post_backward`` when gradients flow back."""

    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_forward(ctx.params)
        if len(args) == 1:
            return args[0]
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_backward(ctx.params)
        # ``params`` gets no gradient; the incoming grads are passed through unchanged.
        return (None,) + grads


class PostFwdPreBwd(torch.autograd.Function):
    """Fires ``post_forward`` when applied and ``pre_backward`` when gradients flow back."""

    @staticmethod
    def forward(ctx, params, args):
        ctx.params = params
        for hook in _ParamOpHookWrapper.hooks:
            hook.post_forward(params)
        return args

    @staticmethod
    def backward(ctx, *grads):
        for hook in _ParamOpHookWrapper.hooks:
            hook.pre_backward(ctx.params)
        return (None,) + grads


@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
    """Temporarily install ``hooks`` as the active parameter-op hooks."""
    try:
        old_param_op_hooks = _ParamOpHookWrapper.hooks
        _ParamOpHookWrapper.hooks = hooks
        yield
    finally:
        _ParamOpHookWrapper.hooks = old_param_op_hooks
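
The two autograd Functions above are meant to be applied around parameter ops by the ColoTensor machinery introduced elsewhere in this PR. As a rough, self-contained sketch of the intended firing order only (the `_LoggingHook` class and the toy matmul below are hypothetical and not part of the repository), a hook could be exercised like this:

# Illustrative sketch, not ColossalAI code: shows when each hook method fires
# when the wrappers bracket an op on a parameter.

class _LoggingHook(ParamOpHook):

    def pre_forward(self, params):
        print("pre_forward", [tuple(p.shape) for p in params])

    def post_forward(self, params):
        print("post_forward", [tuple(p.shape) for p in params])

    def pre_backward(self, params):
        print("pre_backward", [tuple(p.shape) for p in params])

    def post_backward(self, params):
        print("post_backward", [tuple(p.shape) for p in params])


if __name__ == "__main__":
    weight = torch.randn(4, 4, requires_grad=True)
    x = torch.randn(2, 4, requires_grad=True)

    with use_param_op_hooks(_LoggingHook()):
        # pre_forward fires here; post_backward fires during backward.
        x = PreFwdPostBwd.apply([weight], x)
        out = x @ weight.t()
        # post_forward fires here; pre_backward fires during backward.
        out = PostFwdPreBwd.apply([weight], out)
        out.sum().backward()
        # Printed order: pre_forward, post_forward, pre_backward, post_backward.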