mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[autoparallel] add getattr handler (#1767)
* [autoparallel] add getattr haandler * polish code * add extra processes for Parameters * add unit test for param resharding cost * add docstring and polish test
This commit is contained in:
@@ -93,17 +93,18 @@ class ColoTracer(Tracer):
|
||||
origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
|
||||
# dispatch the arguments generator depending on the kind and target in origin arguments.
|
||||
args_metas, _ = extract_meta(*args, **kwargs)
|
||||
handle = None
|
||||
if kind == "call_function":
|
||||
if bias_addition_function.has(target):
|
||||
return bias_addition_function.get(target)(self, target, args, kwargs)
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
return bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
if bias_addition_function.has(method):
|
||||
return bias_addition_function.get(method)(self, target, args, kwargs)
|
||||
handle = bias_addition_function.get(method)(self, target, args, kwargs)
|
||||
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
@@ -115,10 +116,12 @@ class ColoTracer(Tracer):
|
||||
if bias_addition_module.has(mod_type) and mod.bias is not None:
|
||||
function_to_substitute = module_to_func_dict[mod_type]
|
||||
handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
|
||||
return handle.generate()
|
||||
finally:
|
||||
self._disable_module_getattr = False
|
||||
|
||||
if handle is not None:
|
||||
return handle.generate()
|
||||
|
||||
# create nodes using patched arguments
|
||||
proxy = super().create_proxy(*origin_arguments)
|
||||
proxy: ColoProxy
|
||||
@@ -254,7 +257,9 @@ class ColoTracer(Tracer):
|
||||
atoms = target.split(".")
|
||||
for atom in atoms:
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
if isinstance(attr_itr, torch.Tensor):
|
||||
if isinstance(attr_itr, torch.nn.parameter.Parameter):
|
||||
meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
|
||||
elif isinstance(attr_itr, torch.Tensor):
|
||||
meta_out = attr_itr.to(device="meta")
|
||||
else:
|
||||
meta_out = attr_itr
|
||||
|
Reference in New Issue
Block a user