Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.
* [siu] add experimental submodules to main branch.
* [siu]
* [siu]
* [analyzer] init.
* [analyzer] readme.
* [analyzer] readme.
* [analyzer] readme.
* [analyzer] readme.
* [test] add test.
* Update symbolic_trace.py
* mark skip tests.
* try except.
* try except.
* try except.
* s
* init
* init
* fix
* skip
* skip

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
colossalai/_analyzer/fx/passes/shape_prop.py | 194 (new file)
@@ -0,0 +1,194 @@
"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""

from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.fx
from torch.autograd.graph import saved_tensors_hooks
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.fx._compatibility import compatibility

Target = Union[Callable[..., Any], str]


class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensors_hooks``.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track the memory allocation and deallocation.

        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model parameters to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.

        buffer_ctx (Dict[int, torch.Tensor]): The same mapping for all model
            buffers, used for the same purpose as ``param_ctx``.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx = {}
        # Guard against ``module=None`` (which the signature allows); otherwise
        # ``module.parameters()`` would raise an ``AttributeError``.
        self.param_ctx = {param.data_ptr(): param for param in module.parameters()} if module else {}
        self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}

    def pack_hook(self, tensor: torch.Tensor):
        # Only tensors that are neither parameters nor buffers are counted as
        # intermediate activations saved for backward.
        if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
            self.ctx[tensor.data_ptr()] = tensor
        return tensor

    def unpack_hook(self, tensor):
        return tensor
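# A minimal usage sketch (illustration only, not part of the original file):
# under ``sim_env`` the autograd engine routes every tensor it saves for
# backward through ``pack_hook``, so after a forward pass ``env.ctx`` holds
# the intermediate activations, keyed by data pointer.
#
#   model = torch.nn.Linear(4, 4)
#   env = sim_env(module=model)
#   with env:
#       loss = model(torch.rand(2, 4)).sum()
#   activation_bytes = sum(t.numel() * t.element_size() for t in env.ctx.values())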
def _normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def _current_device(module):
    return next(module.parameters()).device
@compatibility(is_backward_compatible=False)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the meta data of the result
    into the corresponding node.

    Usage:
        >>> model = MyModule()
        >>> x = torch.rand(10, 10)
        >>> gm = colossalai.fx.symbolic_trace(model, meta_args={'x': x})
        >>> interp = ShapeProp(gm)
        >>> interp.propagate(x)

    Args:
        module (GraphModule): The module to be executed

    Hints:
        If you want to add a new shape propagation rule, you can do so by
        adding a new method to this class with the ``@register_shape_impl``
        decorator. The method should take the node's ``(*args, **kwargs)`` as
        input and return the node's output.

        For example, you could add a shape propagation rule for
        ``torch.nn.functional.linear`` with the ``@register_shape_impl``
        decorator as shown below (since ``MetaTensorMode`` is already
        compatible with ``torch.nn.functional.linear``, in practice you
        don't need to do this):

        >>> @register_shape_impl(torch.nn.functional.linear)
        >>> def linear_shape_impl(*args, **kwargs):
        >>>     # do something here
        >>>     return torch.empty(output_shape, device=output_device)
    """
    _custom_dispatch_func = {}
    _mode = MetaTensorMode()

    def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)
        self.global_hook = sim_env(module=self.module)
    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Run a specific node ``n`` and return the result. Attach
        (``inputs``, ``outputs``, ``parameters``, ``buffers``) to ``n``.

        Args:
            n (Node): The ``Node`` to execute

        Returns:
            Any: The result of executing ``n``
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        with self.global_hook:
            r = getattr(self, n.op)(n.target, args, kwargs)

        unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
        is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)

        if n.op == 'call_module':
            submod = self.fetch_attr(n.target)
            n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
            n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})

        else:
            # For non-module nodes, collect parameters passed in directly as
            # positional or keyword arguments.
            n_info.parameters.update({
                k.name: MetaTensor(v)
                for k, v in zip(n.args, args)
                if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
            })
            n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
                        tuple(v for v in kwargs.values() if is_pure_tensor(v))

        n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))    # align with SPMD
        n_info.global_ctx = self.global_hook.ctx
        n_info.curr_ctx = self.global_hook.ctx.copy()

        # Mark an output as an alias if its data pointer matches a tensor already
        # recorded by ``sim_env``, i.e. it reuses saved activation memory instead
        # of allocating new memory.
        crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
        n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
        return r
    def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.
        If the target of the ``Node`` is registered with ``@register_shape_impl``,
        the registered function will be used to execute the node. This is common
        if we insert some customized kernels.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        if target in self._custom_dispatch_func:
            return self._custom_dispatch_func[target](*args, **kwargs)
        else:
            return super().call_function(target, args, kwargs)
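    # A sketch of the dispatch hook above (illustration only, not part of the
    # original file): ``call_function`` looks targets up in
    # ``_custom_dispatch_func``, so a shape rule for an op that
    # ``MetaTensorMode`` cannot handle could be registered by mapping the
    # target to a function that returns a tensor of the expected output shape.
    #
    #     def my_kernel_shape_impl(x, weight, *args, **kwargs):
    #         return torch.empty(x.shape[0], weight.shape[0], device=x.device)
    #
    #     ShapeProp._custom_dispatch_func[my_custom_kernel] = my_kernel_shape_impl
    #
    # ``my_custom_kernel`` and ``my_kernel_shape_impl`` are hypothetical names;
    # the ``@register_shape_impl`` decorator mentioned in the class docstring
    # presumably populates this table, but it is not defined in this file.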
    def propagate(self, *args, device=None):
        """
        Run ``module`` via interpretation, recording the shape of each node.

        Args:
            *args (Tensor): The sample input.
            device (torch.device, optional): The device on which the inputs are
                placed when wrapped as ``MetaTensor``.

        Returns:
            Any: The value returned from executing the Module
        """
        wrap_fn = lambda elem: MetaTensor(elem, device=device)
        with self._mode:
            return super().run(*tree_map(wrap_fn, args))
def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
    """
    Run ``module`` via interpretation, record the shape of each ``Node``, and
    return the annotated ``GraphModule``.

    Args:
        module (GraphModule): The GraphModule to profile
        *args (Any): The sample input

    Returns:
        GraphModule: The same GraphModule with shape information
    """
    ShapeProp(module).propagate(*args, device=_current_device(module))
    return module
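A minimal end-to-end sketch of how this pass might be driven (illustration only, not part of the commit: ``MLP`` is a hypothetical module, and the ``colossalai.fx.symbolic_trace`` entry point is taken from the class docstring above; ``run_node`` attaches the unwrapped result of each node to ``n._meta_data``):

    >>> import torch
    >>> import colossalai
    >>> from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
    >>>
    >>> model = MLP()                       # hypothetical torch.nn.Module
    >>> x = torch.rand(10, 10)
    >>> gm = colossalai.fx.symbolic_trace(model, meta_args={'x': x})
    >>> gm = shape_prop_pass(gm, x)
    >>> for n in gm.graph.nodes:
    ...     print(n.name, getattr(n, '_meta_data', None))    # shape-annotated output of each node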