[gemini] fix param op hook when output is tuple (#5355)

* [gemini] fix param op hook when output is tuple

* [gemini] fix param op hook
This commit is contained in:
Hongxin Liu
2024-02-04 11:58:26 +08:00
committed by GitHub
parent 1c790c0877
commit 2dd01e3a14
2 changed files with 8 additions and 5 deletions

View File

@@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from .colo_tensor import _convert_output
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
def is_no_hook_op(func) -> bool:
return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
def filter_colo_parameters(*args, **kwargs):