mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer
* fix all node handler tests
* polish
* polish
@@ -51,7 +51,10 @@ def _normalize_tuple(x):
 
 
 def _current_device(module):
-    return next(module.parameters()).device
+    try:
+        return next(module.parameters()).device
+    except StopIteration:
+        return torch.device('cpu')
 
 
 @compatibility(is_backward_compatible=False)
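The try/except matters because modules that own no parameters (activations, dropout, some containers) have an empty parameters() iterator, so next() raises StopIteration. A minimal sketch of the failure mode the fallback covers; the module names below are illustrative, not taken from the patch:

import torch
import torch.nn as nn

# nn.ReLU owns no parameters, so next() on its parameters() raises StopIteration;
# the patched helper now falls back to torch.device('cpu') instead of crashing.
act = nn.ReLU()
try:
    dev = next(act.parameters()).device
except StopIteration:
    dev = torch.device('cpu')
print(dev)    # cpu

# A module with weights still reports the device of its first parameter.
lin = nn.Linear(4, 4)
print(next(lin.parameters()).device)    # cpu, or cuda:0 after lin.cuda()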
@@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter):
                     return t.to('meta')
 
             if isinstance(elem, MetaTensor):
+                if getattr(self, '_is_param', False):
+                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                 return _convert_meta(elem._tensor)
 
             elif isinstance(elem, torch.Tensor):
+                if isinstance(elem, torch.nn.Parameter):
+                    return torch.nn.Parameter(_convert_meta(elem))
                 return _convert_meta(elem)
 
             else:
                 return elem
 
-        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
         is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)
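The new branches keep parameter-ness intact while still pushing everything onto the meta device: if the wrapped tensor is (or should be treated as) an nn.Parameter, the converted meta tensor is re-wrapped before later passes see it. A simplified, self-contained sketch of that behavior, with MetaTensor-specific details stripped away (the real unwrap_fn also consults the _is_param flag shown in the diff):

import torch

def unwrap_to_meta(elem):
    # Simplified stand-in for the patched unwrap_fn: convert tensors to the
    # meta device, but re-wrap parameters so they stay nn.Parameter instances.
    def _convert_meta(t: torch.Tensor):
        t = t.detach()    # drop autograd history before the device change
        return t if t.device.type == 'meta' else t.to('meta')

    if isinstance(elem, torch.nn.Parameter):
        return torch.nn.Parameter(_convert_meta(elem))
    elif isinstance(elem, torch.Tensor):
        return _convert_meta(elem)
    return elem

w = torch.nn.Parameter(torch.randn(4, 4))
out = unwrap_to_meta(w)
print(out.is_meta, isinstance(out, torch.nn.Parameter))    # True True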
@@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter):
         n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
                         tuple(v for v in kwargs.values() if is_pure_tensor(v))
 
-        n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))    # align with SPMD
+        # align with SPMD
+        if isinstance(r, (tuple, list)):
+            n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
+        else:
+            n._meta_data = unwrap_fn(r)
 
         n_info.global_ctx = self.global_hook.ctx
         n_info.curr_ctx = self.global_hook.ctx.copy()
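The replacement makes n._meta_data match what the SPMD pipeline expects: a multi-output node keeps its tuple structure, while a single-output node stores the tensor itself rather than a one-element tuple. A small sketch of the two cases, with unwrap_fn stubbed as an identity since the actual conversion is shown in the previous hunk:

import torch
from torch.utils._pytree import tree_map

def _normalize_tuple(x):
    return x if isinstance(x, tuple) else (x,)

unwrap_fn = lambda t: t    # placeholder for the meta-conversion above

def to_meta_data(r):
    # Mirrors the patched branch: keep tuple/list outputs as tuples, but store
    # a single-output result directly instead of as a one-element tuple.
    if isinstance(r, (tuple, list)):
        return tree_map(unwrap_fn, _normalize_tuple(r))
    return unwrap_fn(r)

single = to_meta_data(torch.empty(2, 3, device='meta'))
multi = to_meta_data(torch.split(torch.empty(6, 3, device='meta'), 2))
print(type(single))               # <class 'torch.Tensor'>
print(type(multi), len(multi))    # <class 'tuple'> 3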
@@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter):
         Return
             Any: The value returned by the function invocation
         """
+        convert_to_param = False
+        if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
+            convert_to_param = True
         if target in self._custom_dispatch_func:
-            return self._custom_dispatch_func[target](*args, **kwargs)
+            res = self._custom_dispatch_func[target](*args, **kwargs)
         else:
-            return super().call_function(target, args, kwargs)
+            res = super().call_function(target, args, kwargs)
+        if convert_to_param:
+            return torch.nn.Parameter(res)
+        else:
+            return res
+
+    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return
+            Any: The value returned by the method invocation
+        """
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args
+
+        target_method = getattr(self_obj.__class__, target)
+
+        convert_to_parameter = False
+        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
+                args[0], torch.nn.parameter.Parameter):
+            convert_to_parameter = True
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        res = getattr(self_obj, target)(*args_tail, **kwargs)
+        if convert_to_parameter:
+            return torch.nn.Parameter(res)
+        else:
+            return res
 
     def propagate(self, *args, device=None):
         """
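Both new code paths exist because view, reshape, and transpose return a plain torch.Tensor even when called on an nn.Parameter, so without the re-wrap a traced weight would lose its parameter status during shape propagation. A quick illustration of the behavior being compensated for:

import torch

w = torch.nn.Parameter(torch.randn(4, 4))

# Reshaping ops drop the Parameter subclass and hand back a plain Tensor.
print(isinstance(w.view(-1), torch.nn.Parameter))                  # False
print(isinstance(torch.transpose(w, 0, 1), torch.nn.Parameter))    # False

# Re-wrapping the result, as the patched call_function/call_method now do,
# restores the Parameter type.
with torch.no_grad():
    rewrapped = torch.nn.Parameter(w.view(-1))
print(isinstance(rewrapped, torch.nn.Parameter))                   # True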