[hotfix] fix param op hook (#1131)

* fix param op hook

* update zero tp test

* fix bugs
Author: ver217
Date: 2022-06-17 16:12:05 +08:00
Committed by: GitHub
Commit: 789cad301b (parent a1a7899cae)
3 changed files with 74 additions and 20 deletions


@@ -11,17 +11,11 @@ def filter_args(func, *args):
     return [arg for arg in args if func(arg)]
 
 
-def unpack_args(*args):
-    if len(args) == 1:
-        return args[0]
-    return args
-
-
 def replace_args(args, kwargs, new_args):
     args = new_args[:len(args)]
     for k, v in zip(kwargs.keys(), new_args[len(args):]):
         kwargs[k] = v
-    return unpack_args(args), kwargs
+    return tuple(args), kwargs
 
 
 class ColoParameter(ColoTensor, torch.nn.Parameter):
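
Note: with unpack_args removed from this file, replace_args always hands back a real tuple of positional arguments; the single-element unwrapping now lives only in param_op_hook.py (see _unpack_args in the next file). A standalone sketch of the new behaviour, runnable outside colossalai; the argument values are illustrative only:

def replace_args(args, kwargs, new_args):
    # the first len(args) entries become the positional args, the rest refill kwargs in order
    args = new_args[:len(args)]
    for k, v in zip(kwargs.keys(), new_args[len(args):]):
        kwargs[k] = v
    return tuple(args), kwargs

# e.g. re-assembling a call after hooks have swapped the argument tensors:
new_args, new_kwargs = replace_args((1, 2), {'alpha': 3}, ['a', 'b', 'c'])
assert new_args == ('a', 'b') and new_kwargs == {'alpha': 'c'}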


@@ -2,6 +2,7 @@ import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Any
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
 class ParamOpHook(ABC):
@@ -74,14 +75,18 @@ class ParamOpHookManager:
             hook.post_backward(params)
 
     @staticmethod
-    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ParamOpHookManager._trigger_pre_forward(params)
-        return PreFwdPostBwd.apply(params, *args)
+        args_info = _get_colo_tensors_info(*args)
+        rets = PreFwdPostBwd.apply(params, *args)
+        return _update_colo_tensors(args_info, *rets)
 
     @staticmethod
-    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, args)
+        arg_info = _get_colo_tensors_info(arg)
+        ret = PostFwdPreBwd.apply(params, arg)
+        return _unpack_args(_update_colo_tensors(arg_info, ret))
 
     @staticmethod
     def has_hook() -> bool:
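
Note: pre_op and post_op now record the (class, spec) of every ColoTensor argument via _get_colo_tensors_info before calling the autograd Functions, and re-wrap the results with _update_colo_tensors afterwards, so the ColoTensor class and its distributed spec are preserved across the PreFwdPostBwd.apply / PostFwdPreBwd.apply round trip. post_op additionally unwraps a single return value through _unpack_args. All three helpers are defined at the end of this file (last hunk below).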
@@ -93,9 +98,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        if len(args) == 1:
-            return args[0]
-        return args
+        return _unpack_args(args)
 
     @staticmethod
     def backward(ctx, *grads):
@@ -114,3 +117,29 @@ class PostFwdPreBwd(torch.autograd.Function):
     def backward(ctx, *grads):
         ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
+
+
+def _unpack_args(args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def _get_colo_tensors_info(*args) -> list:
+    info = []
+    for arg in args:
+        if isinstance(arg, ColoTensor):
+            info.append((arg.__class__, arg.spec))
+        else:
+            info.append(None)
+    return info
+
+
+def _update_colo_tensors(info, *args) -> list:
+    ret = []
+    for t_info, arg in zip(info, args):
+        if t_info is not None:
+            t_cls, spec = t_info
+            arg = t_cls.from_torch_tensor(arg, spec=spec)
+        ret.append(arg)
+    return ret
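
Note: the new helpers implement a record-and-rewrap pattern: remember each ColoTensor argument's class and spec, let the autograd Function run on the tensors, then rebuild the ColoTensors via from_torch_tensor. A self-contained sketch of that same pattern; FakeColoTensor and the "replicated" spec string are toy stand-ins, not colossalai API, and the two helper bodies mirror the ones added above with ColoTensor swapped for the toy class:

import torch

class FakeColoTensor(torch.Tensor):
    # toy stand-in for ColoTensor: carries a .spec and offers from_torch_tensor()
    spec = None

    @classmethod
    def from_torch_tensor(cls, tensor, spec=None):
        t = tensor.as_subclass(cls)
        t.spec = spec
        return t

def _get_colo_tensors_info(*args) -> list:
    # remember (class, spec) for every "ColoTensor" argument, None for everything else
    return [(a.__class__, a.spec) if isinstance(a, FakeColoTensor) else None for a in args]

def _update_colo_tensors(info, *args) -> list:
    # re-wrap plain tensors with the remembered class and spec
    ret = []
    for t_info, arg in zip(info, args):
        if t_info is not None:
            t_cls, spec = t_info
            arg = t_cls.from_torch_tensor(arg, spec=spec)
        ret.append(arg)
    return ret

x = FakeColoTensor.from_torch_tensor(torch.ones(2), spec="replicated")
info = _get_colo_tensors_info(x)             # record (class, spec) per argument
y = x.as_subclass(torch.Tensor) * 2          # some op that hands back a plain torch.Tensor
[x2] = _update_colo_tensors(info, y)         # class and spec are restored on the result
assert isinstance(x2, FakeColoTensor) and x2.spec == "replicated"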