[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
This commit is contained in:
YuliangLiu0306
2023-03-30 17:47:24 +08:00
committed by GitHub
parent e78a1e949a
commit fee2af8610
36 changed files with 481 additions and 386 deletions

View File

@@ -51,7 +51,10 @@ def _normalize_tuple(x):
def _current_device(module):
return next(module.parameters()).device
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
@compatibility(is_backward_compatible=False)
@@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter):
return t.to('meta')
if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
elif isinstance(elem, torch.Tensor):
if isinstance(elem, torch.nn.Parameter):
return torch.nn.Parameter(_convert_meta(elem))
return _convert_meta(elem)
else:
return elem
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
@@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter):
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD
# align with SPMD
if isinstance(r, (tuple, list)):
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
else:
n._meta_data = unwrap_fn(r)
n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy()
@@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter):
Return
Any: The value returned by the function invocation
"""
convert_to_param = False
if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
convert_to_param = True
if target in self._custom_dispatch_func:
return self._custom_dispatch_func[target](*args, **kwargs)
res = self._custom_dispatch_func[target](*args, **kwargs)
else:
return super().call_function(target, args, kwargs)
res = super().call_function(target, args, kwargs)
if convert_to_param:
return torch.nn.Parameter(res)
else:
return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
Any: The value returned by the method invocation
"""
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
target_method = getattr(self_obj.__class__, target)
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
res = getattr(self_obj, target)(*args_tail, **kwargs)
if convert_to_parameter:
return torch.nn.Parameter(res)
else:
return res
def propagate(self, *args, device=None):
"""