diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
index b895eb038..841dd19a1 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
+from colossalai import META_COMPATIBILITY
 
 INF = float("inf")
 
@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
     mem_limit -= parameter_size(gm)
 
     # prepare data
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
     chain = _normalize_flops(chain, flops)
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 2634e4352..0aed73151 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -2,12 +2,12 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
-from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai import META_COMPATIBILITY
 
 
 # this is the python compute table code from rotor
@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
 
     node_list = linearize(gm, cnode)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
-    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
 
     chain: Chain = _construct_chain(node_list, data)
diff --git a/colossalai/fx/profiler/constant.py b/colossalai/fx/profiler/constant.py
index 7219be1ff..d923346fb 100644
--- a/colossalai/fx/profiler/constant.py
+++ b/colossalai/fx/profiler/constant.py
@@ -14,6 +14,7 @@ if META_COMPATIBILITY:
         aten.transpose.int,
         aten.view.default,
         aten._unsafe_view.default,
+        aten._reshape_alias.default,
     ]
 
     INPLACE_NEW = [
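Note on the two solver changes above: the `MetaTensor` wrapping is now guarded by `META_COMPATIBILITY`, so on torch builds without meta-tensor support `MetaInfoProp` still receives the raw `data`. A minimal sketch of the guarded path, assuming this patch is applied (`prepare_profiling_input` is a hypothetical helper; `gm` and `data` stand for a traced `ColoGraphModule` and its sample input):

    from colossalai import META_COMPATIBILITY

    def prepare_profiling_input(gm, data):
        # same guard as in solver_rotor/solver_pofo: wrap the sample input in a
        # MetaTensor only when the installed torch exposes meta-tensor support
        if META_COMPATIBILITY:
            from colossalai.fx.profiler import MetaTensor
            data = MetaTensor(data, fake_device=next(gm.parameters()).device)
        return data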
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index dcafc2aa3..563a234d9 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -37,9 +37,28 @@ def detach(x):
     x.requires_grad_(requires_grad)
 
 
-def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
-    Profile a Callable function with args and kwargs.
+    Profile a Callable function with args and kwargs on concrete devices.
+
+    Args:
+        target (Callable): A Callable function
+        args (Any): Positional arguments
+        kwargs (Any): Keyword arguments
+
+    Raises:
+        NotImplementedError: TODO(yby)
+
+    Returns:
+        out (Tuple[Any, ...]): The output of the profiled function.
+        meta_info (GraphInfo): The memory cost and FLOPs estimated with `MetaTensor`.
+    """
+    raise NotImplementedError
+
+
+def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
+    """
+    Profile a Callable function with args and kwargs on meta devices.
 
     Args:
         target (Callable): A Callable function
@@ -67,7 +86,7 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
     # Hopefully, this attempt will provide a better estimation of memory.
     class FlopTensor(MetaTensor):
 
-        _node: Node
+        _node: Node = None
 
         def __repr__(self):
             if self.grad_fn:
@@ -76,34 +95,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
-            def get_node(x):
-                return None if not hasattr(x, '_node') else x._node
-
-            args_node = tree_map(get_node, args)
-            kwargs_node = tree_map(get_node, kwargs)
+            args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
+            kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)
 
-            # do not allocate on physical devices
-            if 'device' in kwargs:
-                fake_device = kwargs['device']
-                kwargs['device'] = torch.device('meta')
+            out = super().__torch_dispatch__(func, types, args, kwargs)
 
-            def unwrap(x):
-                nonlocal fake_device
-                if isinstance(x, MetaTensor):
-                    fake_device = x.device
-                    x = x._tensor
-                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    fake_device = x.device
-                    x = x.to(torch.device('meta'))
-                return x
-
-            args = tree_map(unwrap, args)
-            kwargs = tree_map(unwrap, kwargs)
-
-            # run aten for backend=WHATEVER but actually on backend=Meta
-            out = func(*args, **kwargs)
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
             node.meta['phase'] = phase
 
@@ -114,52 +111,41 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
             if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
                 node.meta['phase'] = Phase.PLACEHOLDER
 
-            # TODO: specify `saved_tensors` for backward memory estimation
+            # TODO(yby): specify `saved_tensors` for backward memory estimation
             node.meta['saved_tensor'] = []
             if phase == Phase.BACKWARD:
                 node.meta['saved_tensor'] = normalize_tuple(out)
 
             def wrap(x):
-                if isinstance(x, torch.Tensor):
-                    nonlocal fake_device
-                    if not x.is_meta:
-                        x = x.to(torch.device('meta'))
-                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
-
-            def set_node(x):
-                x._node = node
+                if isinstance(x, MetaTensor):
+                    x = FlopTensor(x)
+                    x._node = node
+                return x
 
             out = tree_map(wrap, out)
-            tree_map(set_node, out)
             return out
 
     def wrap(x):
-        fake_device = None
-        if isinstance(x, MetaTensor):
-            fake_device = x.device
-            x = x._tensor
-        detach(x)
-        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
-
-    # Basically, we need to detach the args and kwargs from the outer graph.
-    args = tree_map(wrap, args)
-    kwargs = tree_map(wrap, kwargs)
-
-    def set_placeholder(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, torch.Tensor):
+            x = FlopTensor(x)
+            if is_autogradable(x):
+                x.requires_grad_(True)
             x._node = subgraph.create_node('placeholder',
                                            'placeholder', (subgraph._root,),
                                            name=subgraph._graph_namespace.create_name('input', x._tensor))
             x._node.meta['phase'] = Phase.PLACEHOLDER
             x._node.meta['saved_tensor'] = []
+        detach(x)
+        return x
 
-    tree_map(set_placeholder, args)
-    tree_map(set_placeholder, kwargs)
+    # Basically, we need to detach the args and kwargs from the outer graph.
+    args = tree_map(wrap, args)
+    kwargs = tree_map(wrap, kwargs)
 
     def pack(x):
         global cache
         if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
-            x._node.meta['saved_tensor'] += [x._tensor]
+            x._node.meta['saved_tensor'] += [x]
             cache.add(x._tensor.data_ptr)
 
         return x
@@ -191,16 +177,12 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphI
     graph_info.fwd_mem_out = activation_size(out)
 
     def unwrap(x):
-        if isinstance(x, FlopTensor):
-            fake_device = x.device
-            x = x._tensor
-            detach(x)
-        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
+        return MetaTensor(x) if isinstance(x, torch.Tensor) else x
 
     return tree_map(unwrap, out), graph_info
 
 
-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost
     and FLOPs of the execution.
@@ -222,7 +204,10 @@ def profile_function(target: 'Target') -> Callable:
         inplace = kwargs.get('inplace', False)
         if inplace:
             kwargs['inplace'] = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             if target in [torch.nn.functional.relu]:
                 meta.save_fwd_in = False
@@ -234,7 +219,7 @@ def profile_function(target: 'Target') -> Callable:
     return f
 
 
-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node to record the memory cost and FLOPs of the execution.
     """
@@ -243,13 +228,16 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(target, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(target, *args, **kwargs)
         return out, meta
 
     return f
 
 
-def profile_module(module: torch.nn.Module) -> Callable:
+def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to record the memory cost and
     FLOPs of the execution.
@@ -271,7 +259,10 @@ def profile_module(module: torch.nn.Module) -> Callable:
         inplace = getattr(module, 'inplace', False)
         if inplace:
             module.inplace = False
-        out, meta = _profile(func, *args, **kwargs)
+        if device == 'meta':
+            out, meta = _profile_meta(func, *args, **kwargs)
+        else:
+            out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             # super-dainiu: experiments on mobilenet_v2 show that `torch.nn.ReLU`
             # is the only inplace activation function that discards its input.
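Note on the profiler changes above: `profile_function`, `profile_method` and `profile_module` now take a `device` argument that selects the backend, with `'meta'` routing to `_profile_meta` and anything else to `_profile_concrete` (which, for now, raises `NotImplementedError`). A usage sketch under that assumption; the input shape is made up, and `torch.nn.functional.relu` is chosen because the wrapper above special-cases it:

    import torch
    from colossalai.fx.profiler.profiler import profile_function

    # device='meta' (the default) takes the _profile_meta path
    profiled_relu = profile_function(torch.nn.functional.relu, device='meta')

    x = torch.rand(4, 4, device='meta')
    out, info = profiled_relu(x)    # info is a GraphInfo with FLOP/memory estimates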
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 45d170437..173eb81d9 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,5 +1,9 @@
+from copy import deepcopy
+from typing import Optional, Union, overload
 import torch
 from torch.utils._pytree import tree_map, tree_flatten
+from torch.types import _bool, _dtype, _device
+from functools import singledispatchmethod
 
 __all__ = ['MetaTensor']
 
@@ -16,6 +20,11 @@ class MetaTensor(torch.Tensor):
 
     @staticmethod
     def __new__(cls, elem, fake_device=None):
+        # Avoid multiple wrapping
+        if isinstance(elem, MetaTensor):
+            fake_device = elem.device if fake_device is None else fake_device
+            elem = elem._tensor
+
         # The wrapping tensor (MetaTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before
@@ -74,3 +83,32 @@ class MetaTensor(torch.Tensor):
             return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
 
         return tree_map(wrap, out)
+
+    @singledispatchmethod
+    def to(self, *args, **kwargs) -> torch.Tensor:
+        """An extension of `torch.Tensor.to()` for `MetaTensor`.
+
+        Returns:
+            result (MetaTensor): a new `MetaTensor` with the requested `fake_device` and/or dtype
+
+        Usage:
+            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
+            >>> tensor.to(torch.uint8)
+            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
+            >>> tensor.to(torch.device('cuda:42'))
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
+            >>> tensor.to('vulkan')
+            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
+        """
+        # this mimics the C++ overloads of `Tensor.to()` via @singledispatchmethod
+        return super().to(*args, **kwargs)
+
+    @to.register
+    def _(self, device: str, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
+
+    @to.register
+    def _(self, device: _device, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False) -> torch.Tensor:
+        result = super().to(dtype, non_blocking, copy) if dtype is not None else self
+        return MetaTensor(deepcopy(result), fake_device=device)
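Note on the `to()` overloads above: `singledispatchmethod` dispatches on the type of the first positional argument, so `str` and `torch.device` targets take the registered overloads, which keep the payload on the meta device and only swap `fake_device` (plus an optional dtype cast), while a leading dtype falls through to the generic `super().to()` path. A small sketch of the intended semantics, assuming this patch is applied:

    import torch
    from colossalai.fx.profiler.tensor import MetaTensor

    t = MetaTensor(torch.rand(10), fake_device='cpu')

    # str / torch.device arguments hit the registered overloads: no real
    # allocation happens, only the advertised fake_device (and dtype) changes
    a = t.to('cuda:0')
    b = t.to(torch.device('cuda:1'), torch.half)

    # a leading dtype is not a registered overload, so this takes the generic
    # path and keeps the current fake_device
    c = t.to(torch.uint8)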