[zero] fix error for BEiT models (#2169)

* [zero] fix error for BEiT models

* [ColoParameter] add unpack operation for tuple arguments

* fix bugs

* fix chunkv2 unit testing

* add assertion for gradient state
HELSON
2022-12-26 15:03:54 +08:00
committed by GitHub
parent 4363ff3e41
commit 2458659919
7 changed files with 82 additions and 32 deletions


@@ -8,8 +8,25 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 def filter_args(func, *args):
     return [arg for arg in args if func(arg)]
 
 
+def filter_colo_parameters(*args, **kwargs):
+    param_list = []
+    def get_colo_parameters(element) -> None:
+        if isinstance(element, list) or isinstance(element, tuple):
+            for e in element:
+                get_colo_parameters(e)
+        elif isinstance(element, dict):
+            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
+        elif isinstance(element, ColoParameter):
+            param_list.append(element)
+        return
+
+    for a in args:
+        get_colo_parameters(a)
+    for v in kwargs.values():
+        get_colo_parameters(v)
+    return param_list
 def replace_args(args, kwargs, new_args):
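
The new helper differs from filter_args exactly when a ColoParameter arrives nested inside a tuple or list argument: the flat filter only inspects top-level arguments, while filter_colo_parameters recurses into containers (and rejects dicts outright). A minimal sketch of that difference, using a hypothetical FakeColoParam stand-in so it runs without colossalai installed:

# Sketch of the old vs. new filtering behavior.
# FakeColoParam is a hypothetical stand-in for ColoParameter, used only so
# this snippet runs without colossalai installed.
class FakeColoParam:
    pass


def filter_args(func, *args):
    # Old behavior: only look at the top-level positional arguments.
    return [arg for arg in args if func(arg)]


def filter_fake_colo_parameters(*args, **kwargs):
    # New behavior: recurse into lists/tuples and collect parameters.
    param_list = []

    def collect(element):
        if isinstance(element, (list, tuple)):
            for e in element:
                collect(e)
        elif isinstance(element, dict):
            raise RuntimeError("Found Dict: can't deal with complicated arguments.")
        elif isinstance(element, FakeColoParam):
            param_list.append(element)

    for a in args:
        collect(a)
    for v in kwargs.values():
        collect(v)
    return param_list


p = FakeColoParam()
args = ((p, None),)  # the parameter is hidden inside a tuple argument

print(len(filter_args(lambda a: isinstance(a, FakeColoParam), *args)))  # 0 -> nested parameter missed
print(len(filter_fake_colo_parameters(*args)))                          # 1 -> nested parameter found
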
@@ -62,7 +79,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         if not func.__name__.startswith('__'):
             if kwargs is None:
                 kwargs = {}
-            params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
+            params = filter_colo_parameters(*args, **kwargs)
             if len(params) > 0:
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
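
This hunk lives in ColoParameter.__torch_function__, so every torch op that touches a ColoParameter passes through it, and the parameters collected by filter_colo_parameters are handed to ColoParamOpHookManager.pre_op before the op executes. Below is a stripped-down sketch of that dispatch pattern; TracedParam and its print statement are illustrative stand-ins, not ColossalAI's actual hook machinery:

import torch


class TracedParam(torch.nn.Parameter):
    """Hypothetical Parameter subclass that intercepts torch ops via __torch_function__."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Recursively collect TracedParam instances, including ones nested in
        # tuples/lists, mirroring what filter_colo_parameters does for ColoParameter.
        found = []

        def collect(element):
            if isinstance(element, (list, tuple)):
                for e in element:
                    collect(e)
            elif isinstance(element, TracedParam):
                found.append(element)

        for a in args:
            collect(a)
        for v in kwargs.values():
            collect(v)

        if found and not func.__name__.startswith('__'):
            # A real hook manager would run its pre-op logic here.
            print(f"{func.__name__} touched {len(found)} traced parameter(s)")

        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)


p = TracedParam(torch.randn(2, 2))
out = torch.add(p, p)  # prints: add touched 2 traced parameter(s)
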