[autoparallel] add getattr handler (#1767)

* [autoparallel] add getattr handler

* polish code

* add extra processing for Parameters

* add unit test for param resharding cost

* add docstring and polish test
YuliangLiu0306
2022-11-03 12:31:33 +08:00
committed by GitHub
parent c6a1a62636
commit 2c4c7b3618
11 changed files with 306 additions and 37 deletions


@@ -93,17 +93,18 @@ class ColoTracer(Tracer):
         origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
         # dispatch the arguments generator depending on the kind and target in origin arguments.
         args_metas, _ = extract_meta(*args, **kwargs)
+        handle = None
         if kind == "call_function":
             if bias_addition_function.has(target):
-                return bias_addition_function.get(target)(self, target, args, kwargs)
+                handle = bias_addition_function.get(target)(self, target, args, kwargs)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
-                return bias_addition_function.get(target.__name__)(self, target, args, kwargs)
+                handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_function.has(method):
-                return bias_addition_function.get(method)(self, target, args, kwargs)
+                handle = bias_addition_function.get(method)(self, target, args, kwargs)
         elif kind == "call_module":
             if not hasattr(self, "orig_forward"):
@@ -115,10 +116,12 @@ class ColoTracer(Tracer):
                 if bias_addition_module.has(mod_type) and mod.bias is not None:
                     function_to_substitute = module_to_func_dict[mod_type]
                     handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
-                    return handle.generate()
             finally:
                 self._disable_module_getattr = False
+        if handle is not None:
+            return handle.generate()
         # create nodes using patched arguments
         proxy = super().create_proxy(*origin_arguments)
         proxy: ColoProxy
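
Taken together, the two hunks above change create_proxy from returning handle.generate() inside each dispatch branch to recording the handle and generating the proxy at a single point after dispatch (and after any try/finally cleanup). A minimal standalone sketch of that pattern, using hypothetical names (DummyHandle, dispatch) rather than the actual ColoTracer registries:

class DummyHandle:
    """Stand-in for the bias-addition handles produced by the registry lookups."""

    def __init__(self, target):
        self.target = target

    def generate(self):
        return f"proxy for {self.target}"


def dispatch(kind, target):
    # Each branch only records a handle; nothing returns early from inside a branch,
    # so any cleanup code placed after the dispatch always runs before generation.
    handle = None
    if kind == "call_function":
        handle = DummyHandle(target)
    elif kind == "call_method":
        handle = DummyHandle(f"{target} (method)")
    # Single exit point, mirroring the new `if handle is not None: return handle.generate()`.
    if handle is not None:
        return handle.generate()
    return None  # fall through to the default proxy-creation path


print(dispatch("call_function", "torch.addmm"))  # proxy for torch.addmm
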
@@ -254,7 +257,9 @@ class ColoTracer(Tracer):
                 atoms = target.split(".")
                 for atom in atoms:
                     attr_itr = getattr(attr_itr, atom)
-                if isinstance(attr_itr, torch.Tensor):
+                if isinstance(attr_itr, torch.nn.parameter.Parameter):
+                    meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
+                elif isinstance(attr_itr, torch.Tensor):
                     meta_out = attr_itr.to(device="meta")
                 else:
                     meta_out = attr_itr
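
The new Parameter branch matters because, in stock PyTorch, calling .to() on an nn.Parameter returns a plain Tensor, so an attribute moved to the meta device would otherwise lose its parameter identity; re-wrapping with torch.nn.Parameter, as the diff does, preserves it. A small standalone check of that PyTorch behavior, independent of the tracer:

import torch

p = torch.nn.Parameter(torch.empty(4, 4))

# .to(device="meta") produces a new tensor, and the Parameter subclass does not survive the op:
moved = p.to(device="meta")
print(type(moved))                                      # <class 'torch.Tensor'>
print(isinstance(moved, torch.nn.parameter.Parameter))  # False

# Re-wrapping keeps the meta tensor recognizable as a Parameter, as in the new branch above.
rewrapped = torch.nn.Parameter(p.to(device="meta"))
print(isinstance(rewrapped, torch.nn.parameter.Parameter))  # True
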