[zero] fix error for BEiT models (#2169)
* [zero] fix error for BEiT models
* [ColoParameter] add unpack operation for tuple arguments
* fix bugs
* fix chunkv2 unit testing
* add assertion for gradient state
@@ -8,8 +8,25 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 def filter_args(func, *args):
     return [arg for arg in args if func(arg)]
 
 
+def filter_colo_parameters(*args, **kwargs):
+    param_list = []
+
+    def get_colo_parameters(element) -> None:
+        if isinstance(element, list) or isinstance(element, tuple):
+            for e in element:
+                get_colo_parameters(e)
+        elif isinstance(element, dict):
+            raise RuntimeError("Found Dict: ColoParameter can't deal with complicated arguments.")
+        elif isinstance(element, ColoParameter):
+            param_list.append(element)
+        return
+
+    for a in args:
+        get_colo_parameters(a)
+    for v in kwargs.values():
+        get_colo_parameters(v)
+
+    return param_list
+
+
 def replace_args(args, kwargs, new_args):
@@ -62,7 +79,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
-                params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
+                params = filter_colo_parameters(*args, **kwargs)
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
                         new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())