From d967779a320134b439729fddb1b487bdee95474d Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Fri, 23 Sep 2022 10:59:47 +0800
Subject: [PATCH] [fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.
* [fx] remove import.
* [fx] remove import.
* [fx] remove import.
* [fx] tune the meta calculations.
* [fx] polish comments.
* [fx] remove assertions.
* [fx] modify test cases.
* [fx] modify test cases.
* [fx] optimize import.
* [fx
---
 colossalai/_meta_registrations.py             |  56 ++++++-
 .../fx/passes/algorithms/ckpt_solver_rotor.py |  12 +-
 colossalai/fx/passes/meta_info_prop.py        |   6 +-
 colossalai/fx/profiler/constant.py            |  78 ++++++++++
 colossalai/fx/profiler/dataflow.py            |  54 +++----
 .../fx/profiler/experimental/profiler.py      |   3 +-
 colossalai/fx/profiler/memory.py              |  88 +---------
 colossalai/fx/profiler/opcount.py             |  15 +-
 colossalai/fx/profiler/profiler.py            | 141 ++++++++++++------
 colossalai/fx/profiler/tensor.py              |  34 ++++-
 colossalai/fx/tracer/_meta_trace.py           |  54 +++++--
 tests/test_fx/test_comm_size_compute.py       |   3 +-
 tests/test_fx/test_meta/test_aten.py          |   7 +-
 tests/test_fx/test_meta/test_backward.py      |  12 +-
 tests/test_fx/test_meta/test_meta_trace.py    |  48 ++++++
 tests/test_fx/test_meta_info_prop.py          |   9 +-
 16 files changed, 413 insertions(+), 207 deletions(-)
 create mode 100644 colossalai/fx/profiler/constant.py
 create mode 100644 tests/test_fx/test_meta/test_meta_trace.py

diff --git a/colossalai/_meta_registrations.py b/colossalai/_meta_registrations.py
index 802150ded..4e58c61c4 100644
--- a/colossalai/_meta_registrations.py
+++ b/colossalai/_meta_registrations.py
@@ -175,6 +175,11 @@ def meta_hardswish(input: torch.Tensor):
     return torch.empty_like(input)
 
 
+@register_meta(aten.hardtanh.default)
+def meta_hardtanh(input: torch.Tensor, min, max):
+    return torch.empty_like(input)
+
+
 @register_meta(aten.hardswish_backward.default)
 def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
     grad_in = torch.empty_like(input)
@@ -189,7 +194,7 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
 
 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
-    return torch.empty_like(input)
+    return input
 
 
 @register_meta(aten.native_batch_norm.default)
@@ -211,13 +216,39 @@ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
     return dX, dgamma, dbeta
 
 
-@register_meta(aten.native_layer_norm.default)
-def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+@register_meta(aten.cudnn_batch_norm.default)
+def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)
 
     output = torch.empty_like(input)
     running_mean = torch.empty((n_input), device='meta')
     running_var = torch.empty((n_input), device='meta')
+    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+    return output, running_mean, running_var, reserve
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+# NB: CuDNN only implements the backward algorithm for batchnorm
+# in training mode (evaluation mode batchnorm has a different algorithm),
+# which is why this doesn't accept a 'training' parameter.
+@register_meta(aten.cudnn_batch_norm_backward.default)
+def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+                           save_mean, save_invstd, eps, reserve):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(weight)
+    dbeta = torch.empty_like(weight)
+    return dX, dgamma, dbeta
+
+
+@register_meta(aten.native_layer_norm.default)
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+    bs = input.size(0)
+    n_input = input.size(1)
+
+    output = torch.empty_like(input)
+    running_mean = torch.empty((bs, n_input, 1), device='meta')
+    running_var = torch.empty((bs, n_input, 1), device='meta')
     return output, running_mean, running_var
 
 
@@ -338,6 +369,23 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
                        layout=grad_output.layout)
 
 
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
-    return torch.empty_like(condition)
+    result_type = torch.result_type(self, other)
+    return torch.empty_like(self, dtype=result_type)
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout.default)
+def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+    # notice that mask is bool
+    output = torch.empty_like(input)
+    mask = torch.empty_like(input, dtype=torch.bool)
+    return output, mask
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+@register_meta(aten.native_dropout_backward.default)
+def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+    return torch.empty_like(grad)
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index f9991c407..2634e4352 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
 from torch.fx import Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import activation_size, parameter_size
+from colossalai.fx.profiler.tensor import MetaTensor
 import math
 from .linearize import linearize
 from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
+        xbar += n.meta['fwd_mem_tmp']
+        if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
+            xbar += n.meta['fwd_mem_out']
     return xbar
 
 
@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
     def _get_deps_size():
         deps_size = 0
         for k, v in deps.items():
+            k: Node
             if v > 0:
                 deps_size += k.meta['bwd_mem_out']
             if v == float('-inf'):
-                deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
+                deps_size -= k.meta['fwd_mem_tmp']
+                if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
+                    deps_size -= k.meta['fwd_mem_out']
         return deps_size
 
 
@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
     """
     node_list = linearize(gm, cnode)
-    mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
+    data = MetaTensor(data, fake_device=next(gm.parameters()).device)
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data)
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 84efca13a..170176d71 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
             tensor_meta = tree_map(extract_tensor_meta, result)
             n.meta['tensor_meta'] = tensor_meta
-            n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
+            n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
             # TODO: the attribute node_size should be removed in the future
             setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
-            for par in n.all_input_nodes:
-                par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
             n.meta['type'] = type(result)
 
         # retain the autograd graph
@@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
             result (Any): The argument value that was retrieved
             meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
         """
-        return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
+        return args[0], GraphInfo(save_fwd_in=True)
 
     def propagate(self, *args):
         """
diff --git a/colossalai/fx/profiler/constant.py b/colossalai/fx/profiler/constant.py
new file mode 100644
index 000000000..7219be1ff
--- /dev/null
+++ b/colossalai/fx/profiler/constant.py
@@ -0,0 +1,78 @@
+import torch
+from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
+from . import META_COMPATIBILITY
+
+__all__ = []
+
+if META_COMPATIBILITY:
+    aten = torch.ops.aten
+
+    ALIAS_ATEN = [
+        # inplace reshaping
+        aten.detach.default,
+        aten.t.default,
+        aten.transpose.int,
+        aten.view.default,
+        aten._unsafe_view.default,
+    ]
+
+    INPLACE_NEW = [
+        aten.empty_like.default,
+        aten.new_empty_strided.default,
+    ]
+
+    INPLACE_MATH_ATEN = [
+        aten.add_.Tensor,
+        aten.sub_.Tensor,
+        aten.div_.Tensor,
+        aten.div_.Scalar,
+        aten.mul_.Tensor,
+        aten.bernoulli_.float,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+
+else:
+    # TODO fill out the inplace ops
+    INPLACE_OPS = [
+        add,
+        sub,
+        mul,
+        floordiv,
+        neg,
+        pos,
+        getitem,
+        setitem,
+        getattr,
+        torch.Tensor.cpu,
+    ]
+
+    # TODO: list all call_methods that are inplace here
+    INPLACE_METHOD = [
+        'transpose',
+        'permute',
+        # TODO: reshape may return a copy of the data if the data is not contiguous
+        'reshape',
+        'dim',
+        'flatten',
+        'size',
+        'view',
+        'unsqueeze',
+        'to',
+        'type',
+        'flatten',
+    ]
+
+    # TODO: list all call_methods that are not inplace here
+    NON_INPLACE_METHOD = [
+        'chunk',
+        'contiguous',
+        'expand',
+        'mean',
+        'split',
+    ]
+    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index 2b4b6c17e..0551f6e25 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -3,9 +3,6 @@
 from enum import Enum
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
-from . import META_COMPATIBILITY
-if META_COMPATIBILITY:
-    from .memory import NORMALIZATION_ATEN, CLONE_ATEN
 
 
 class Phase(Enum):
@@ -23,29 +20,32 @@ class GraphInfo:
     ============================================================================
                             -------------------------------
                             |            Node             |
     [fwd_in] are ---------> | [fwd_in]         [bwd_out]  | <----- [bwd_out] marks the memory for `grad_out`.
     placeholders saved for  |    |  \__________     |     |
     backward.               |    |             \    |     |
                             | [fwd_tmp] ------> [bwd_tmp] | <-----
                             |    |  \_________     |     |        [bwd_tmp] marks the peak memory
                             |    /             \    \    |        in backward pass.
     [x] is not counted ---> | [x]   [fwd_tmp] -> [bwd_tmp] | <-----
-    in [fwd_tmp] because    |    |       \_____     |    |
-    it is not saved for     |    |             \    |    |
-    backward.               -------------------------------
+    in [fwd_tmp] because    |    |       \_____     |    |
+    it is not saved for     |    |             \    |    |
+    backward.               | [fwd_out]        \    |    | <----- [fwd_out] is [fwd_in] for the next node.
+                            -------------------------------
     ============================================================================
     Attributes:
         fwd_flop (int): The forward FLOPs of a certain node
         bwd_flop (int): The backward FLOPs of a certain node.
-        fwd_mem_in (int): See the above illustration.
+        save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
         fwd_mem_tmp (int): See the above illustration.
+        fwd_mem_out (int): See the above illustration.
         bwd_mem_tmp (int): See the above illustration.
         bwd_mem_out (int): See the above illustration.
     """
     fwd_flop: int = 0
     bwd_flop: int = 0
-    fwd_mem_in: int = 0
+    save_fwd_in: bool = False
     fwd_mem_tmp: int = 0
+    fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
     bwd_mem_out: int = 0
 
@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
 
 
 def is_saved(n: Node):
-    return n.meta.get('saved', False)
+    return len(n.meta['saved_tensor'])
 
 
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
-                peak_mem += activation_size(k.meta['out'])
-            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
-                peak_mem -= activation_size(k.meta['out'])
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
+                peak_mem += activation_size(k.meta['saved_tensor'])
+            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta['saved_tensor'])
         return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
@@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
-            # A forward tensor who is marked `save` but is not
-            # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
-            # Any `fwd_mem_in` should be kept in memory even this function
-            # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
-            # the node, `fwd_mem_tmp` can be freed.
-            if is_phase(n, Phase.PLACEHOLDER):
-                graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_phase(n, Phase.FORWARD):
-                graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+        deps[n] = len(n.users)
+        # A forward tensor that is marked `save` but is also
+        # an input to `Phase.FORWARD` should be saved during forward.
+        # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+        # Any `fwd_mem_in` should be kept in memory even if this function
+        # is checkpointed.
+        # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+        # the node, `fwd_mem_tmp` can be freed.
+        if is_phase(n, Phase.PLACEHOLDER):
+            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+        if is_phase(n, Phase.FORWARD):
+            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
         for input_n in n.all_input_nodes:
             if input_n in deps:
                 deps[input_n] -= 1
diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index 954e8b49b..c7c3f81dd 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -3,7 +3,8 @@ from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx.node import Argument, Target
 from . import meta_profiler_function, meta_profiler_module
-from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from ..memory import activation_size
+from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
 
 __all__ = ['profile_function', 'profile_module', 'profile_method']
 
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py
index 2e0f1a058..96233a9df 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/memory.py
@@ -1,88 +1,10 @@
 import torch
 from torch.fx import Node
 from typing import Union, Dict, List, Tuple
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY
 
 __all__ = ['activation_size', 'parameter_size', 'is_inplace']
 
-if META_COMPATIBILITY:
-    aten = torch.ops.aten
-
-    WEIRD_OPS = [
-        torch.where,
-    ]
-
-    INPLACE_ATEN = [
-        aten.add_.Tensor,
-        aten.sub_.Tensor,
-        aten.div_.Tensor,
-        aten.div_.Scalar,
-        aten.mul_.Tensor,
-        aten.bernoulli_.float,
-
-        # inplace reshaping
-        aten.copy_.default,
-        aten.detach.default,
-        aten.t.default,
-        aten.transpose.int,
-        aten.view.default,
-        aten._unsafe_view.default,
-    ]
-
-    NORMALIZATION_ATEN = [
-        aten.native_batch_norm.default,
-        aten.native_layer_norm.default,
-        # aten.max_pool2d_with_indices.default,
-    ]
-
-    CLONE_ATEN = [
-        aten.clone.default,
-    ]
-
-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
-
-else:
-    # TODO fill out the inplace ops
-    INPLACE_OPS = [
-        add,
-        sub,
-        mul,
-        floordiv,
-        neg,
-        pos,
-        getitem,
-        setitem,
-        getattr,
-        torch.Tensor.cpu,
-    ]
-
-    # TODO: list all call_methods that are inplace here
-    INPLACE_METHOD = [
-        'transpose',
-        'permute',
-        # TODO: reshape may return a copy of the data if the data is not contiguous
-        'reshape',
-        'dim',
-        'flatten',
-        'size',
-        'view',
-        'unsqueeze',
-        'to',
-        'type',
-        'flatten',
-    ]
-
-    # TODO: list all call_methods that are not inplace here
-    NON_INPLACE_METHOD = [
-        'chunk',
-        'contiguous',
-        'expand',
-        'mean',
-        'split',
-    ]
-    __all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
-
 
 def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Calculate activation size of a node.
@@ -106,13 +28,13 @@ def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
 
 
 def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate param size of a node.
+    """Calculate parameter size of a node.
 
     Args:
        mod (torch.nn.Module): The target `torch.nn.Module`
 
     Returns:
-        int: The param size
+        int: The parameter size
     """
     param_size = 0
     for param in mod.parameters():
@@ -132,8 +54,10 @@ def is_inplace(n: Node):
     inplace = False
     if n.op == "call_function":
         inplace = n.kwargs.get("inplace", False)
-        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
-            inplace = True
+        if META_COMPATIBILITY:
+            from .constant import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
     elif n.op == "call_module":
         inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
 
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 4d51e0eea..3e2662eef 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -1,7 +1,7 @@
 # adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
 # ideas from https://pastebin.com/AkvAyJBw
 
-from functools import reduce
+from functools import partial, reduce
 import operator
 from typing import Callable, List, Any
 from numbers import Number
@@ -147,8 +147,9 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
     return norm_flop_jit
 
 
-def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
-    training = inputs[-3]
+def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
+    if training is None:
+        training = inputs[-3]
     assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
     if training:
         return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore
@@ -201,6 +202,8 @@ flop_mapping = {
     # normalization
     aten.native_batch_norm.default: batchnorm_flop_jit,
     aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
     aten.native_layer_norm.default: norm_flop_counter(2, 0),
     aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
 
@@ -247,12 +250,14 @@ elementwise_flop_aten = [
     aten.hardswish.default,
     aten.hardswish_.default,
     aten.hardswish_backward.default,
+    aten.hardtanh.default,
     aten.hardtanh_.default,
     aten.hardtanh_backward.default,
     aten.hardsigmoid_backward.default,
     aten.hardsigmoid.default,
     aten.gelu.default,
     aten.gelu_backward.default,
+    aten.silu.default,
     aten.silu_.default,
     aten.silu_backward.default,
     aten.sigmoid.default,
@@ -264,6 +269,10 @@ elementwise_flop_aten = [
     aten.tanh.default,
     aten.tanh_backward.default,
     aten.threshold_backward.default,
+
+    # dropout
+    aten.native_dropout.default,
+    aten.native_dropout_backward.default,
 ]
 
 for op in elementwise_flop_aten:
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 8051a753c..dcafc2aa3 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -1,15 +1,21 @@
+from functools import partial
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS
+from .dataflow import autograd_graph_analysis, is_phase, Phase, GraphInfo
+from .memory import activation_size
+from .constant import ALIAS_ATEN
 from .tensor import MetaTensor
 from .opcount import flop_mapping
 
 __all__ = ['profile_function', 'profile_module', 'profile_method']
 
+# super-dainiu: this cache should be global, otherwise it cannot
+# track duplicated tensors between nodes
+cache = set()
+
 
 def normalize_tuple(x):
     if not isinstance(x, tuple):
@@ -21,7 +27,17 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()
 
 
-def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
+# super-dainiu:
+# x.detach() will change the unique identifier of data_ptr
+# we need to handle this in a stupid way
+def detach(x):
+    if isinstance(x, torch.Tensor):
+        requires_grad = x.requires_grad
+        x.requires_grad_(False)
+        x.requires_grad_(requires_grad)
+
+
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
     """
     Profile a Callable function with args and kwargs.
 
@@ -55,8 +71,8 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
 
         def __repr__(self):
             if self.grad_fn:
-                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
-            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
+                return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, grad_fn={self.grad_fn})"
+            return f"FlopTensor({self._tensor}, fake_device='{self.device}', size={tuple(self.shape)}, requires_grad={self.requires_grad})"
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -68,27 +84,47 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             kwargs_node = tree_map(get_node, kwargs)
             node = subgraph.create_node('call_function', func, args_node, kwargs_node)
 
-            # do not allocate on `cpu`
+            # do not allocate on physical devices
             if 'device' in kwargs:
-                kwargs['device'] = 'meta'
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
 
             def unwrap(x):
-                # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = FlopTensor(x.to('meta'))
-                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+                nonlocal fake_device
+                if isinstance(x, MetaTensor):
+                    fake_device = x.device
+                    x = x._tensor
+                elif isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x
 
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)
 
-            # run aten for backend=CPU but actually on backend=Meta
+            # run aten for backend=WHATEVER but actually on backend=Meta
             out = func(*args, **kwargs)
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
-            node.meta['out'] = normalize_tuple(out)
             node.meta['phase'] = phase
 
+            # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
+            # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
+            # `Phase.FORWARD`
+            if phase == Phase.FORWARD:
+                if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
+                    node.meta['phase'] = Phase.PLACEHOLDER
+
+            # TODO: specify `saved_tensors` for backward memory estimation
+            node.meta['saved_tensor'] = []
+            if phase == Phase.BACKWARD:
+                node.meta['saved_tensor'] = normalize_tuple(out)
+
             def wrap(x):
-                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return FlopTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
 
             def set_node(x):
                 x._node = node
 
             tree_map(set_node, out)
             return out
 
-    # `WEIRD_OPS` are tough to handle because they don't accept autograd
-    # on meta tensor.
-    if target not in WEIRD_OPS:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
-    else:
-
-        def wrap(x):
-            return FlopTensor(
-                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+    def wrap(x):
+        fake_device = None
+        if isinstance(x, MetaTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return FlopTensor(x.requires_grad_(True), fake_device=fake_device) if is_autogradable(x) else x
 
     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)
@@ -120,14 +151,16 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
             'placeholder', (subgraph._root,),
             name=subgraph._graph_namespace.create_name('input', x._tensor))
         x._node.meta['phase'] = Phase.PLACEHOLDER
-        x._node.meta['out'] = (x._tensor,)
+        x._node.meta['saved_tensor'] = []
 
     tree_map(set_placeholder, args)
     tree_map(set_placeholder, kwargs)
 
     def pack(x):
+        global cache
-        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
-            x._node.meta['saved'] = True
+        if isinstance(x, FlopTensor) and not x._tensor.data_ptr in cache:
+            x._node.meta['saved_tensor'] += [x._tensor]
+            cache.add(x._tensor.data_ptr)
         return x
 
     def unpack(x):
@@ -146,19 +179,23 @@ def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
 
     # If the output is not a floating point `torch.Tensor` or it does not
     # require grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
-        phase = Phase.BACKWARD
-        if isinstance(out, FlopTensor):
-            out._node.meta['save'] = False
-        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
-            out, device='meta')
-        torch.autograd.backward(out, FlopTensor(grad))
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            phase = Phase.BACKWARD
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, FlopTensor) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor, FlopTensor(grad, fake_device=tensor.device), retain_graph=True)
 
     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
+    graph_info.fwd_mem_out = activation_size(out)
 
     def unwrap(x):
-        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
+        if isinstance(x, FlopTensor):
+            fake_device = x.device
+            x = x._tensor
+            detach(x)
+        return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
 
     return tree_map(unwrap, out), graph_info
 
@@ -181,13 +218,15 @@ def profile_function(target: 'Target') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # If there is an argument that this `call_function` is inplace, we should
-        # skip the autograd profiling.
-        if kwargs.get('inplace', False):
-            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
-            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `target`
+        inplace = kwargs.get('inplace', False)
+        if inplace:
+            kwargs['inplace'] = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            if target in [torch.nn.functional.relu]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta
 
     f.__name__ = target.__name__
@@ -228,13 +267,17 @@ def profile_module(module: torch.nn.Module) -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # If there is an argument that this `call_module` is inplace, we should
-        # skip the autograd profiling.
-        if getattr(module, 'inplace', False):
-            args = tree_map(lambda x: x.to('meta'), args)
-            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
-            out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        # still run the profiling but discard some results regarding `module`.
+        inplace = getattr(module, 'inplace', False)
+        if inplace:
+            module.inplace = False
         out, meta = _profile(func, *args, **kwargs)
+        if inplace:
+            # super-dainiu: experiments on mobilenet_v2 show that `torch.nn.ReLU`
+            # is the only inplace activation function that discards its input.
+            if type(module) in [torch.nn.ReLU]:
+                meta.save_fwd_in = False
+            meta.bwd_mem_out = 0
         return out, meta
 
     f.__name__ = module.__class__.__name__
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 5956a1046..45d170437 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -7,6 +7,7 @@ __all__ = ['MetaTensor']
 class MetaTensor(torch.Tensor):
     """
     A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
+    `fake_device` is the device that `MetaTensor` is supposed to run on.
""" _tensor: torch.Tensor @@ -14,7 +15,7 @@ class MetaTensor(torch.Tensor): __slots__ = ['_tensor'] @staticmethod - def __new__(cls, elem): + def __new__(cls, elem, fake_device=None): # The wrapping tensor (MetaTensor) shouldn't hold any # memory for the class in question, but it should still # advertise the same device as before @@ -25,24 +26,37 @@ class MetaTensor(torch.Tensor): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device='cpu', + device=fake_device if fake_device is not None else elem.device, requires_grad=elem.requires_grad) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. + if not r._tensor.is_meta: + r._tensor = r._tensor.to(torch.device('meta')) + # only tensor not on `meta` should be copied to `meta` return r def __repr__(self): if self.grad_fn: - return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})" - return f"MetaTensor({self._tensor})" + return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})" + return f"MetaTensor({self._tensor}, fake_device='{self.device}')" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + fake_device = None def unwrap(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'): - x = MetaTensor(x) - return x._tensor.to('meta') if isinstance(x, MetaTensor) else x + nonlocal fake_device + if isinstance(x, MetaTensor): + fake_device = x.device + x = x._tensor + elif isinstance(x, torch.Tensor): + fake_device = x.device + x = x.to(torch.device('meta')) + return x + + if 'device' in kwargs: + fake_device = kwargs['device'] + kwargs['device'] = torch.device('meta') args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -53,6 +67,10 @@ class MetaTensor(torch.Tensor): # Now, we want to continue propagating this tensor, so we rewrap Tensors in # our custom tensor subclass def wrap(x): - return MetaTensor(x) if isinstance(x, torch.Tensor) else x + if isinstance(x, torch.Tensor): + nonlocal fake_device + if not x.is_meta: + x = x.to(torch.device('meta')) + return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x return tree_map(wrap, out) diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 181a28fe9..a7f7c8159 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -1,10 +1,21 @@ +from colossalai.fx.profiler.memory import activation_size import torch from torch.fx import Node, Graph from torch.fx.graph import _Namespace from torch.utils._pytree import tree_map -def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def is_autogradable(x): + return isinstance(x, torch.Tensor) and x.is_floating_point() + + +def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph: """Trace forward and backward graph with MetaTensor Args: @@ -33,7 +44,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: __slots__ = ['_tensor', '_node'] @staticmethod - def __new__(cls, tensor, placeholder=False, name=None): + def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): r = torch.Tensor._make_wrapper_subclass( cls, tensor.size(), @@ -41,7 +52,7 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph: storage_offset=tensor.storage_offset(), dtype=tensor.dtype, layout=tensor.layout, - device='cpu', + device=fake_device if 
                requires_grad=tensor.requires_grad)    # deceive the frontend for aten selections
             r._tensor = tensor
             if placeholder:
@@ -51,15 +62,23 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
                     'placeholder', (graph._root,),
                     name=namespace.create_name(name, tensor))
             # ...the real tensor is held as an element on the tensor.
+            if not r._tensor.is_meta:
+                r._tensor = r._tensor.to(torch.device('meta'))
             return r
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 
             def unwrap(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
-                    x = MetaProxy(x)
-                return x._tensor.to('meta') if isinstance(x, MetaProxy) else x
+                nonlocal fake_device
+                if isinstance(x, MetaProxy):
+                    fake_device = x.device
+                    x = x._tensor
+                    # assert not isinstance(x, MetaProxy)
+                elif isinstance(x, torch.Tensor):
+                    fake_device = x.device
+                    x = x.to(torch.device('meta'))
+                return x
 
             def get_node(x):
                 if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
@@ -70,6 +89,10 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             kwargs_node = tree_map(get_node, kwargs)
             node = graph.create_node('call_function', func, args_node, kwargs_node)
 
+            if 'device' in kwargs:
+                fake_device = kwargs['device']
+                kwargs['device'] = torch.device('meta')
+
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)
 
@@ -79,7 +102,12 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             # Now, we want to continue propagating this tensor, so we rewrap Tensors in
             # our custom tensor subclass
             def wrap(x):
-                return MetaProxy(x) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                if isinstance(x, torch.Tensor):
+                    nonlocal fake_device
+                    if not x.is_meta:
+                        x = x.to(torch.device('meta'))
+                return MetaProxy(
+                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
 
             def set_node(x):
                 x._node = node
@@ -90,10 +118,18 @@ def meta_trace(module: torch.nn.Module, *args, **kwargs) -> Graph:
             return out
 
     def wrap(x):
-        return MetaProxy(x, True) if isinstance(x, torch.Tensor) else x
+        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x
 
     args = tree_map(wrap, args)
     kwargs = tree_map(wrap, kwargs)
 
-    module(*args, **kwargs).sum().backward()
+    out = module(*args, **kwargs)
+
+    for tensor in normalize_tuple(out):
+        if is_autogradable(tensor) and tensor.requires_grad:
+            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
+                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
+            torch.autograd.backward(tensor,
+                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
+                                    retain_graph=True)
     return graph
diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py
index e4d1ff32b..bc4348c97 100644
--- a/tests/test_fx/test_comm_size_compute.py
+++ b/tests/test_fx/test_comm_size_compute.py
@@ -33,8 +33,9 @@ class MLP(torch.nn.Module):
 
 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_comm_size_compute():
+    from colossalai.fx.profiler import MetaTensor
     model = MLP(MODEL_DIM)
-    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
+    input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)
     annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py
index 49b978270..61eda1d67 100644
--- a/tests/test_fx/test_meta/test_aten.py
+++ b/tests/test_fx/test_meta/test_aten.py
@@ -62,11 +62,8 @@ def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
 
 def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
     x.requires_grad = requires_backward
-    meta_x = MetaTensor(x.to('meta'))
-    if isinstance(f, nn.Module):
-        x_out, meta_out = f(x), f.to('meta')(meta_x)
-    else:
-        x_out, meta_out = f(x), f(meta_x)
+    meta_x = MetaTensor(x)
+    x_out, meta_out = f(x), f(meta_x)
     compare_all(x_out, meta_out)
     if requires_backward:
         x_out.sum().backward()
diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py
index e497792af..84ac56881 100644
--- a/tests/test_fx/test_meta/test_backward.py
+++ b/tests/test_fx/test_meta/test_backward.py
@@ -30,17 +30,17 @@ tmm_models = [
 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_torchvision_models():
     for m in tm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
 
 
 @pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
 def test_timm_models():
     for m in tmm_models:
-        model = m().to('meta')
-        data = torch.rand(1000, 3, 224, 224, device='meta')
-        model(MetaTensor(data)).sum().backward()
+        model = m()
+        data = torch.rand(100000, 3, 224, 224, device='meta')
+        model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
 
 
 if __name__ == '__main__':
diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py
new file mode 100644
index 000000000..67b69f1da
--- /dev/null
+++ b/tests/test_fx/test_meta/test_meta_trace.py
@@ -0,0 +1,48 @@
+import torchvision.models as tm
+import timm.models as tmm
+import torch
+from colossalai import META_COMPATIBILITY
+import pytest
+
+if META_COMPATIBILITY:
+    from colossalai.fx import meta_trace
+
+tm_models = [
+    tm.vgg11,
+    tm.resnet18,
+    tm.densenet121,
+    tm.mobilenet_v3_small,
+    tm.resnext50_32x4d,
+    tm.wide_resnet50_2,
+    tm.regnet_x_16gf,
+    tm.mnasnet0_5,
+    tm.efficientnet_b0,
+]
+
+tmm_models = [
+    tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
+    tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
+    tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
+    tmm.swin_transformer.swin_base_patch4_window7_224
+]
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_torchvision_models_trace():
+    for m in tm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
+def test_timm_models_trace():
+    for m in tmm_models:
+        model = m()
+        data = torch.rand(1000, 3, 224, 224, device='meta')
+        graph = meta_trace(model, torch.device('cpu'), data)
+
+
+if __name__ == '__main__':
+    test_torchvision_models_trace()
+    test_timm_models_trace()
diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py
index fa9067ae3..7f1051987 100644
--- a/tests/test_fx/test_meta_info_prop.py
+++ b/tests/test_fx/test_meta_info_prop.py
@@ -1,12 +1,8 @@
 import torch
-import torch.nn as nn
-import colossalai
-import colossalai.nn as col_nn
 from torch.fx import symbolic_trace
+from colossalai import META_COMPATIBILITY
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
-import pytest
-
 
 BATCH_SIZE = 2
 DIM_IN = 4
 DIM_OUT = 16
@@ -22,6 +18,9 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
 def test_meta_info_prop():
     model = torch.nn.Linear(DIM_IN, DIM_OUT)
     input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
+    if META_COMPATIBILITY:
+        from colossalai.fx.profiler import MetaTensor
+        input_sample = MetaTensor(input_sample, fake_device='cpu')
     orig_output = model(input_sample)
     gm = symbolic_trace(model)
     MetaInfoProp(gm).run(input_sample)
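
Usage sketch (illustrative, not part of the patch): the snippet below shows how the reworked profiler is expected to be driven after this change, following the pattern of the updated tests above. It assumes torch >= 1.12 so that META_COMPATIBILITY holds and that this PR is applied; the TwoLayerMLP module, its dimensions, and the printed meta keys are made-up examples, not APIs introduced here.

    # A minimal sketch, assuming the patched colossalai profiler and torch >= 1.12.
    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import MetaTensor


    class TwoLayerMLP(torch.nn.Module):
        # Hypothetical module used only for illustration.
        def __init__(self, dim=16):
            super().__init__()
            self.fc1 = torch.nn.Linear(dim, dim)
            self.fc2 = torch.nn.Linear(dim, dim)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))


    model = TwoLayerMLP()
    gm = symbolic_trace(model)

    # Inputs are now wrapped in MetaTensor with a `fake_device`: allocation stays on
    # the `meta` backend while `fake_device` records the device the tensor pretends
    # to live on, mirroring test_comm_size_compute.py and test_meta_info_prop.py.
    data = MetaTensor(torch.rand(2, 16, device='meta'), fake_device='cpu')
    MetaInfoProp(gm).run(data)

    # After propagation each node's meta carries the GraphInfo fields from this
    # patch (save_fwd_in, fwd_mem_tmp, fwd_mem_out, bwd_mem_tmp, bwd_mem_out, ...),
    # which the rotor checkpoint solver reads to estimate per-node memory.
    for node in gm.graph.nodes:
        print(node.name,
              node.meta.get('save_fwd_in'),
              node.meta.get('fwd_mem_tmp'),
              node.meta.get('fwd_mem_out'),
              node.meta.get('bwd_mem_tmp'),
              node.meta.get('bwd_mem_out'))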