[gemini] fix param op hook when output is tuple (#5355)
* [gemini] fix param op hook when output is tuple
* [gemini] fix param op hook
parent 1c790c0877
commit 2dd01e3a14
@@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
 from .colo_tensor import _convert_output
 
-WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
 
 
 def is_no_hook_op(func) -> bool:
-    return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
+    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
 
 
 def filter_colo_parameters(*args, **kwargs):
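Note on this hunk: torch.Tensor.is_floating_point does not start with "__", so keeping it in WHITE_LIST_FUNCS was dead code; the startswith("__") test in is_no_hook_op could never reach it. Splitting it out into NO_HOOK_FUNCS makes the hook genuinely skip it. A minimal sketch of the new predicate's behavior (the assertions are illustrative, not part of the patch):

import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}

def is_no_hook_op(func) -> bool:
    # Dunder ops are unhooked unless whitelisted; NO_HOOK_FUNCS covers
    # non-dunder ops that the startswith("__") test can never match.
    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS

assert is_no_hook_op(torch.Tensor.is_floating_point)  # now skipped by the hook
assert is_no_hook_op(torch.Tensor.__add__)            # plain dunder op, skipped
assert not is_no_hook_op(torch.Tensor.__getitem__)    # whitelisted, still hooked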
@@ -92,7 +92,10 @@ class ColoParamOpHookManager:
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        return PostFwdPreBwd.apply(params, arg)
+        # in case the output is a tuple, we have to flatten it
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
+        new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)
 
     @staticmethod
     def has_hook() -> bool:
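Note on this hunk: post_op used to pass the whole output object to PostFwdPreBwd.apply, which loses autograd tracking when an op returns a tuple of tensors. The fix flattens the output, routes only the grad-requiring tensors through the autograd Function, and merges everything back into the original structure. A minimal sketch of that round trip, assuming the helpers behave as their names in the diff suggest (tree_flatten and tree_unflatten are the real torch.utils._pytree utilities):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def _flatten_grad_args(args):
    # Split a pytree into grad-requiring tensors and everything else,
    # remembering which slot each leaf came from.
    flat_args, spec = tree_flatten(args)
    grad_args, other_args, grad_flags = [], [], []
    for arg in flat_args:
        flag = isinstance(arg, torch.Tensor) and arg.requires_grad
        grad_flags.append(flag)
        (grad_args if flag else other_args).append(arg)
    return grad_args, other_args, grad_flags, spec

def _merge_args(grad_args, other_args, grad_flags, spec):
    # Reassemble the leaves in their original order and structure.
    grad_iter, other_iter = iter(grad_args), iter(other_args)
    flat = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
    return tree_unflatten(flat, spec)

t = torch.randn(2, requires_grad=True)
out = (t, 3, t.detach())                  # a tuple output with mixed leaves
g, o, flags, spec = _flatten_grad_args(out)
merged = _merge_args(g, o, flags, spec)
assert merged[0] is t and merged[1] == 3  # same objects, same positions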
@@ -113,7 +116,7 @@ class PreFwdPostBwd(torch.autograd.Function):
 
 class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, params, args):
+    def forward(ctx, params, *args):
         ctx.params = params
         return args
 
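Note on this hunk: torch.autograd.Function only connects tensors that appear as top-level positional arguments to apply; a tensor buried inside a tuple passed as a single argument is invisible to the graph, so the backward hook would never fire for it. Switching forward(ctx, params, args) to forward(ctx, params, *args) lets post_op unpack the flattened tensors from the hunk above. A small self-contained illustration:

import torch

class Probe(torch.autograd.Function):
    # Identity-style Function mirroring the PostFwdPreBwd pattern.
    @staticmethod
    def forward(ctx, *args):
        return args

    @staticmethod
    def backward(ctx, *grads):
        print("backward saw", len(grads), "grads")
        return grads

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)
a, b = Probe.apply(x, y)        # unpacked: both tensors are tracked
(a.sum() + b.sum()).backward()  # prints: backward saw 2 grads

# Probe.apply((x, y)) would instead pass a single tuple object; autograd
# inspects only top-level positional args, so the tensors inside would not
# be attached to the graph and Probe.backward would never run.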
@@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
             grad_args.append(arg)
         else:
             other_args.append(arg)
-    assert len(grad_args) > 0
     return grad_args, other_args, grad_flags, spec
 
 
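Note on this hunk: with outputs now routed through _flatten_grad_args, the helper can legitimately see a pytree that contains no grad-requiring tensors at all, so the old assert would trip on valid results. A tiny illustration (the boolean output below is a hypothetical example, not from the patch):

import torch
from torch.utils._pytree import tree_flatten

out = (torch.tensor(True), 42)  # e.g. a comparison result plus metadata
flat, _ = tree_flatten(out)
grad_args = [a for a in flat if isinstance(a, torch.Tensor) and a.requires_grad]
assert len(grad_args) == 0      # the removed assert would have crashed here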