mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[gemini] fix param op hook when output is tuple (#5355)
* [gemini] fix param op hook when output is tuple * [gemini] fix param op hook
This commit is contained in:
@@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
|
||||
from .colo_tensor import _convert_output
|
||||
|
||||
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
|
||||
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
|
||||
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
|
||||
|
||||
|
||||
def is_no_hook_op(func) -> bool:
|
||||
return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
|
||||
return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
|
||||
|
||||
|
||||
def filter_colo_parameters(*args, **kwargs):
|
||||
|
Reference in New Issue
Block a user