Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] fix error for BEiT models (#2169)
* [zero] fix error for BEiT models
* [ColoParameter] add unpack operation for tuple arguments (see the sketch below)
* fix bugs
* fix chunkv2 unit testing
* add assertion for gradient state
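A minimal sketch of the tuple-argument handling this commit adds. The helper names _is_grad_tensor and _get_grad_args come from the diff below, but the torch.cat-style call shape is only an illustrative assumption, not taken from the commit. When a torch function is applied to a tuple of parameters (as apparently happens in BEiT, per the commit title), the hook receives ((a, b), dim) rather than plain tensor arguments, so the first positional argument has to be searched for grad tensors and split from the trailing non-tensor arguments:

import torch

def _is_grad_tensor(obj) -> bool:
    # A tensor participates in autograd if it carries a grad_fn or requires grad.
    if torch.is_tensor(obj):
        if obj.grad_fn is not None or obj.requires_grad:
            return True
    return False

def _get_grad_args(*args):
    # Fast path: some top-level argument is already a grad tensor.
    for obj in args:
        if _is_grad_tensor(obj):
            return args, None
    # Otherwise the first argument must be a tuple containing grad tensors
    # (e.g. torch.cat((a, b), dim=0) reaches the hook as ((a, b), 0)).
    arg_zero = args[0]
    assert isinstance(arg_zero, tuple) and any(_is_grad_tensor(t) for t in arg_zero)
    return arg_zero, args[1:]

a = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
print(_get_grad_args((a, b), 0))   # ((a, b), (0,)): tensors split from trailing args
print(_get_grad_args(a, b))        # ((a, b), None): plain tensor arguments pass through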
@@ -82,16 +82,26 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        args_info = _get_colo_tensors_info(*args)
-        rets = PreFwdPostBwd.apply(params, *args)
-        return _update_colo_tensors(args_info, *rets)
+        grad_args, rear_args = _get_grad_args(*args)
+        colo_info = _get_colo_tensors_info(*grad_args)
+        rets = PreFwdPostBwd.apply(params, *grad_args)
+        update_args = _update_colo_tensors(colo_info, *rets)
+        if rear_args is None:
+            return update_args
+        else:
+            arg_zero = (tuple(update_args),)
+            return arg_zero + rear_args

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        arg_info = _get_colo_tensors_info(arg)
+        colo_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
-        return _unpack_args(_update_colo_tensors(arg_info, ret))
+        res = _update_colo_tensors(colo_info, ret)
+        if len(res) == 1:
+            return res[0]
+        else:
+            return res

     @staticmethod
     def has_hook() -> bool:

@@ -103,7 +113,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def forward(ctx, params, *args):
         ctx.params = params
-        return _unpack_args(args)
+        return args

     @staticmethod
     def backward(ctx, *grads):

@@ -124,10 +134,29 @@ class PostFwdPreBwd(torch.autograd.Function):
         return (None,) + grads


-def _unpack_args(args):
-    if len(args) == 1:
-        return args[0]
-    return args
+def _is_grad_tensor(obj) -> bool:
+    if torch.is_tensor(obj):
+        if obj.grad_fn is not None or obj.requires_grad:
+            return True
+    return False
+
+
+def _get_grad_args(*args):
+    # returns the identical args if there is a grad tensor
+    for obj in args:
+        if _is_grad_tensor(obj):
+            return args, None
+    # otherwise, the first argument should be a tuple of grad tensors
+    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
+    arg_zero = args[0]
+    if not isinstance(arg_zero, tuple):
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    check_grad_flag = False
+    for obj in arg_zero:
+        check_grad_flag |= _is_grad_tensor(obj)
+    if not check_grad_flag:
+        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
+    return arg_zero, args[1:]


 def _get_colo_tensors_info(*args) -> list:
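For completeness, a simplified and self-contained sketch of the new pre_op control flow from the first hunk. pre_op_sketch and the identity stand-ins for PreFwdPostBwd.apply / _update_colo_tensors are hypothetical simplifications, not the library code; the point is the re-wrapping step, where hooked tensors are packed back into (tuple(update_args),) + rear_args so the calling torch function receives arguments in the same shape it passed in:

import torch
from typing import Any

def _is_grad_tensor(obj) -> bool:
    return torch.is_tensor(obj) and (obj.grad_fn is not None or obj.requires_grad)

def _get_grad_args(*args):
    # Same splitting rule as in the diff above (error handling omitted here).
    if any(_is_grad_tensor(obj) for obj in args):
        return args, None
    return args[0], args[1:]

def pre_op_sketch(*args) -> Any:
    grad_args, rear_args = _get_grad_args(*args)
    # Stand-in for PreFwdPostBwd.apply(...) and _update_colo_tensors(...);
    # the real code restores ColoTensor metadata onto the hooked outputs.
    update_args = list(grad_args)
    if rear_args is None:
        return update_args
    # Re-wrap so a tuple-taking op (e.g. torch.cat) sees ((t1, t2, ...), *rest).
    return (tuple(update_args),) + rear_args

a = torch.randn(2, requires_grad=True)
b = torch.randn(2, requires_grad=True)
out = pre_op_sketch((a, b), 0)
print(type(out[0]), out[1])   # <class 'tuple'> 0 : argument shape preserved
print(pre_op_sketch(a, b))    # [tensor, tensor] : plain args stay a flat list

On the return path, post_op now does the single-element unpack itself (returning res[0] when len(res) == 1) instead of relying on the removed _unpack_args inside PreFwdPostBwd.forward.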