Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] fix param op hook (#1131)
* fix param op hook
* update zero tp test
* fix bugs
```diff
@@ -11,17 +11,11 @@ def filter_args(func, *args):
     return [arg for arg in args if func(arg)]
 
 
-def unpack_args(*args):
-    if len(args) == 1:
-        return args[0]
-    return args
-
-
 def replace_args(args, kwargs, new_args):
     args = new_args[:len(args)]
     for k, v in zip(kwargs.keys(), new_args[len(args):]):
         kwargs[k] = v
-    return unpack_args(args), kwargs
+    return tuple(args), kwargs
 
 
 class ColoParameter(ColoTensor, torch.nn.Parameter):
```
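The `replace_args` change above normalizes its return type: the old path funneled the sliced list through the now-removed `unpack_args`, which received the whole list as its single argument and handed it straight back, so callers got a `list` where an args `tuple` is conventionally expected. A minimal sketch of the fixed behavior (the lambda call site is illustrative, not from the commit):

```python
import torch

def replace_args(args, kwargs, new_args):
    # fixed helper from the hunk above: positional args now come back as a tuple
    args = new_args[:len(args)]
    for k, v in zip(kwargs.keys(), new_args[len(args):]):
        kwargs[k] = v
    return tuple(args), kwargs

x = torch.randn(4)
new_args, new_kwargs = replace_args((x,), {'scale': 2.0}, [x + 1, 3.0])
assert isinstance(new_args, tuple)  # previously this came back as a list
out = (lambda t, scale=1.0: t * scale)(*new_args, **new_kwargs)
```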
```diff
@@ -2,6 +2,7 @@ import torch
 from contextlib import contextmanager
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Any
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
 class ParamOpHook(ABC):
```
```diff
@@ -74,14 +75,18 @@ class ParamOpHookManager:
             hook.post_backward(params)
 
     @staticmethod
-    def pre_op(params: List[torch.Tensor], *args: Any) -> Any:
+    def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ParamOpHookManager._trigger_pre_forward(params)
-        return PreFwdPostBwd.apply(params, *args)
+        args_info = _get_colo_tensors_info(*args)
+        rets = PreFwdPostBwd.apply(params, *args)
+        return _update_colo_tensors(args_info, *rets)
 
     @staticmethod
-    def post_op(params: List[torch.Tensor], args: Any) -> Any:
+    def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, args)
+        arg_info = _get_colo_tensors_info(arg)
+        ret = PostFwdPreBwd.apply(params, arg)
+        return _unpack_args(_update_colo_tensors(arg_info, ret))
 
     @staticmethod
     def has_hook() -> bool:
```
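For context, `pre_op`/`post_op` are meant to bracket an op that touches `ColoParameter`s: they fire the registered hooks and, after this fix, record and restore `ColoTensor` metadata around the autograd functions. A hedged sketch of a caller (`wrapped_linear` is illustrative, not part of this commit, and the import path is assumed from the module being diffed):

```python
import torch
import torch.nn.functional as F

# assumed import path; ParamOpHookManager lives in the module this diff touches
from colossalai.tensor.param_op_hook import ParamOpHookManager

def wrapped_linear(input_, weight, bias):
    # illustrative wrapper: bracket the real op so registered hooks fire
    # around it and ColoTensor specs survive the round trip
    params = [weight, bias]
    if ParamOpHookManager.has_hook():
        input_, weight, bias = ParamOpHookManager.pre_op(params, input_, weight, bias)
    out = F.linear(input_, weight, bias)
    if ParamOpHookManager.has_hook():
        out = ParamOpHookManager.post_op(params, out)
    return out
```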
```diff
@@ -93,9 +98,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        if len(args) == 1:
-            return args[0]
-        return args
+        return _unpack_args(args)
 
     @staticmethod
     def backward(ctx, *grads):
```
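The inlined length check becomes the shared `_unpack_args` helper added later in this commit. The unpacking matters because `torch.autograd.Function` pairs one incoming gradient with each forward output, so a single tensor must be returned bare rather than as a 1-tuple. A self-contained sketch of the same pass-through pattern (the print is a stand-in for the real `ParamOpHookManager` trigger):

```python
import torch

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params, *args):
        ctx.params = params
        # one input tensor -> one output tensor; a 1-tuple here would change
        # the number of outputs that backward() must match
        return args[0] if len(args) == 1 else args

    @staticmethod
    def backward(ctx, *grads):
        print(f"pre-backward hook would fire for {len(ctx.params)} params")
        return (None,) + grads  # None pairs with the non-tensor `params` input

p = torch.nn.Parameter(torch.ones(2))
x = torch.randn(2, requires_grad=True)
y = PassThrough.apply([p], x)
y.sum().backward()  # prints the stand-in hook message before grads reach x
```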
```diff
@@ -114,3 +117,29 @@ class PostFwdPreBwd(torch.autograd.Function):
     def backward(ctx, *grads):
         ParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
+
+
+def _unpack_args(args):
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
+def _get_colo_tensors_info(*args) -> list:
+    info = []
+    for arg in args:
+        if isinstance(arg, ColoTensor):
+            info.append((arg.__class__, arg.spec))
+        else:
+            info.append(None)
+    return info
+
+
+def _update_colo_tensors(info, *args) -> list:
+    ret = []
+    for t_info, arg in zip(info, args):
+        if t_info is not None:
+            t_cls, spec = t_info
+            arg = t_cls.from_torch_tensor(arg, spec=spec)
+        ret.append(arg)
+    return ret
```
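These helpers exist because `autograd.Function.apply` hands back plain `torch.Tensor`s, dropping the `ColoTensor` subclass and its spec; the manager therefore records `(class, spec)` before the apply and rebuilds the wrappers afterwards. A minimal stand-alone sketch of that round-trip (`FakeColoTensor` is a stand-in; the real code uses `ColoTensor.from_torch_tensor`):

```python
import torch

class FakeColoTensor(torch.Tensor):
    """Stand-in for ColoTensor: a Tensor subclass carrying a `spec` attribute."""

    @staticmethod
    def from_torch_tensor(tensor, spec=None):
        t = tensor.as_subclass(FakeColoTensor)
        t.spec = spec
        return t

x = FakeColoTensor.from_torch_tensor(torch.randn(2, 2), spec="replicated")

# 1) record (class, spec) for each ColoTensor-like arg, None otherwise
info = [(x.__class__, x.spec), None]

# 2) the autograd function strips the subclass (simulated here)
plain = x.detach().as_subclass(torch.Tensor)

# 3) rebuild the wrappers from the recorded info, as _update_colo_tensors does
restored = [t[0].from_torch_tensor(arg, spec=t[1]) if t else arg
            for t, arg in zip(info, [plain, 123])]
assert isinstance(restored[0], FakeColoTensor) and restored[0].spec == "replicated"
```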