From 0584654c792fab4375c31f11a2d90e22c8a03b04 Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Wed, 26 Oct 2022 14:24:41 +0800
Subject: [PATCH] [fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.
* [fx] add shard utils.
* [fx] fix import.
* [fx] check code style.
* [fx] add comment.
* [autoparallel] first move.
* [fx] add time computations.
---
 .../fx/passes/algorithms/ckpt_solver_chen.py  |  4 +-
 .../fx/passes/algorithms/ckpt_solver_rotor.py | 20 ++--
 colossalai/fx/passes/concrete_info_prop.py    | 19 ++--
 colossalai/fx/passes/meta_info_prop.py        | 27 ++++--
 colossalai/fx/profiler/__init__.py            | 10 +-
 colossalai/fx/profiler/dataflow.py            | 10 +-
 .../fx/profiler/experimental/__init__.py      |  2 +-
 .../fx/profiler/experimental/profiler.py      | 16 ++--
 .../{memory.py => shard_utils.py}             |  0
 colossalai/fx/profiler/memory_utils.py        | 71 +++++++++++++++
 colossalai/fx/profiler/profiler.py            | 16 ++--
 .../fx/profiler/{memory.py => shard_utils.py} | 91 ++++++-------------
 colossalai/fx/tracer/_meta_trace.py           |  4 +-
 .../test_profiler_meta_info_prop.py           |  9 +-
 14 files changed, 177 insertions(+), 122 deletions(-)
 rename colossalai/fx/profiler/experimental/{memory.py => shard_utils.py} (100%)
 create mode 100644 colossalai/fx/profiler/memory_utils.py
 rename colossalai/fx/profiler/{memory.py => shard_utils.py} (58%)

diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
index e38ddbdce..52000ebe5 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -1,7 +1,9 @@
+import math
 from typing import List, Set, Tuple
+
 import torch
 from torch.fx import GraphModule, Node
-import math
+
 from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

 __all__ = ['chen_greedy']

diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index 01c3bdb35..5b8d0da9f 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -1,15 +1,17 @@
+import math
 import sys
 from typing import List, Tuple
-from colossalai.fx.profiler.memory import calculate_fwd_in
+
 from torch.fx import Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
-import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
+
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
 from colossalai.logging import get_dist_logger

+from .linearize import linearize
+from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
+
 # global variable to indicate whether the solver has failed
 SOLVER_FAILED = False

@@ -18,7 +20,7 @@ SOLVER_FAILED = False
 # https://gitlab.inria.fr/hiepacs/rotor
 # paper link: https://hal.inria.fr/hal-02352969
 def _compute_table(chain: Chain, mmax) -> Tuple:
-    """Returns the optimal table: a tuple containing:
+    """Returns the optimal table: a tuple containing:
     Opt[m][lmin][lmax] with lmin = 0...chain.length and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
     what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     """Get the forward xbar of a node

     Args:
-        node (List[Node]): List of torch.fx Node,
+        node (List[Node]): List of torch.fx Node,
         indicates a node in linearized graph

     Returns:
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,

     # build module if module not found
     except ModuleNotFoundError:
-        import subprocess
         import os
+        import subprocess
         logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
         this_dir = os.path.dirname(os.path.abspath(__file__))
         result = subprocess.Popen(

diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py
index 191d8d67d..ab38e8cb1 100644
--- a/colossalai/fx/passes/concrete_info_prop.py
+++ b/colossalai/fx/passes/concrete_info_prop.py
@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple

 import torch
 import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_flatten

+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
+

 @compatibility(is_backward_compatible=True)
 class ConcreteInfoProp(torch.fx.Interpreter):
@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
         DIM_HIDDEN = 16
         DIM_OUT = 16
         model = torch.nn.Sequential(
-            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
+            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
             torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
             ).cuda()
         input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
         gm = symbolic_trace(model)
         interp = ConcreteInfoProp(gm)
         interp.run(input_sample)
-        print(interp.summary(unit='kb'))
-
-
-        output of above code is
+        print(interp.summary(unit='kb'))
+
+
+        output of above code is
         Op type      Op       Forward time    Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
         -----------  -------  --------------  ---------------  -------------  ---------  ---------  ---------  ---------
         placeholder  input_1  0.0 s           0.0 s            False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):

     def summary(self, unit: str = 'MB') -> str:
         """
-        Summarizes the memory and FLOPs statistics of the `GraphModule` in
-        tabular format. Note that this API requires the ``tabulate`` module
+        Summarizes the memory and FLOPs statistics of the `GraphModule` in
+        tabular format. Note that this API requires the ``tabulate`` module
         to be installed.
         """
         # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
""" # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index 4fab5d041..90009b22b 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple import torch import torch.fx -from colossalai.fx._compatibility import compatibility -from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp, - profile_function, profile_method, profile_module) from torch.fx.node import Argument, Node, Target from torch.utils._pytree import tree_map +from colossalai.fx._compatibility import compatibility +from colossalai.fx.profiler import ( + GraphInfo, + activation_size, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_tmp, + profile_function, + profile_method, + profile_module, +) + @compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): @@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter): DIM_HIDDEN = 16 DIM_OUT = 16 model = torch.nn.Sequential( - torch.nn.Linear(DIM_IN, DIM_HIDDEN), + torch.nn.Linear(DIM_IN, DIM_HIDDEN), torch.nn.Linear(DIM_HIDDEN, DIM_OUT), ) input_sample = torch.rand(BATCH_SIZE, DIM_IN) @@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter): interp = MetaInfoProp(gm) interp.run(input_sample) print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB - - - # output of above code is + + + # output of above code is Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP ----------- ------- --------------- ---------------- --------- --------- --------- --------- placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB @@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter): def summary(self, unit: str = 'MB') -> str: """ - Summarizes the memory and FLOPs statistics of the `GraphModule` in - tabular format. Note that this API requires the ``tabulate`` module + Summarizes the memory and FLOPs statistics of the `GraphModule` in + tabular format. Note that this API requires the ``tabulate`` module to be installed. 
""" # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index b520ff124..8bcbde0eb 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -1,12 +1,18 @@ from .._compatibility import is_compatible_with_meta if is_compatible_with_meta(): - from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp from .opcount import flop_mapping from .profiler import profile_function, profile_method, profile_module + from .shard_utils import ( + calculate_bwd_time, + calculate_fwd_in, + calculate_fwd_out, + calculate_fwd_time, + calculate_fwd_tmp, + ) from .tensor import MetaTensor else: from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out from .dataflow import GraphInfo -from .memory import activation_size, is_inplace, parameter_size +from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index f7009a84a..a5e888032 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -6,7 +6,7 @@ from typing import Dict, List from torch.fx import Graph, Node from .._compatibility import compatibility -from .memory import activation_size, is_inplace +from .memory_utils import activation_size, is_inplace class Phase(Enum): @@ -29,7 +29,7 @@ class GraphInfo: placeholders saved for | | \__________ | | backward. | | \ | | | [fwd_tmp] ------> [bwd_tmp] | <----- - | | \_________ | | [bwd_tmp] marks the peak memory + | | \_________ | | [bwd_tmp] marks the peak memory | / \ \ | | in backward pass. [x] is not counted ---> | [x] [fwd_tmp] -> [bwd_tmp] | <----- in [fwd_tmp] because | | \_____ | | @@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo: Nodes should have attribute `out` indicating the output of each node. ============================================================================ Placeholder ----> p o <---- We need to keep track of grad out - |\________ | + |\________ | ↓ ↘| f --------> b |\ \_____ ↑ | \ ↘ / f f ----> b <---- Not every forward result needs to be saved for backward | \____ ↑ - ↘ ↘| + ↘ ↘| f ----> b <---- Backward can be freed as soon as it is required no more. ↘ ↗ l - ============================================================================= + ============================================================================= Args: graph (Graph): The autograd graph with nodes marked for keyword `phase`. 
diff --git a/colossalai/fx/profiler/experimental/__init__.py b/colossalai/fx/profiler/experimental/__init__.py
index fbb6ff624..a5387981e 100644
--- a/colossalai/fx/profiler/experimental/__init__.py
+++ b/colossalai/fx/profiler/experimental/__init__.py
@@ -1,5 +1,5 @@
-from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
 from .profiler import profile_function, profile_method, profile_module
 from .profiler_function import *
 from .profiler_module import *
 from .registry import meta_profiler_function, meta_profiler_module
+from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp

diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index fbeea5128..5c545260e 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -5,7 +5,7 @@ import torch
 from torch.fx.node import Argument, Target

 from ..._compatibility import compatibility
-from ..memory import activation_size
+from ..memory_utils import activation_size
 from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
 from .registry import meta_profiler_function, meta_profiler_module
@@ -27,7 +27,7 @@ class GraphInfo:
     placeholders saved for  |   |     \__________     |  |
     backward.               |   |                \    |  |
                             | [fwd_tmp] ------> [bwd_tmp] |    <-----
-                            |   |    \_________     |  |    [bwd_tmp] marks the peak memory
+                            |   |    \_________     |  |    [bwd_tmp] marks the peak memory
                             |  / \           \      |  |    in backward pass.
     [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
     in [fwd_tmp] because    |  |   |    \_____      |  |
@@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
 @compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target') -> Callable:
     """
-    Wrap a `call_function` node or `torch.nn.functional` in order to
+    Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
     Unfortunately, backward memory cost and FLOPs are estimated results.
-
+
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn.functional` are available.
-
+
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
@@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
 @compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module) -> Callable:
     """
-    Wrap a `call_module` node or `torch.nn` in order to
+    Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
-
+
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn` are available.
-
+
     Example:
         >>> input = torch.rand(4, 3, 224, 224, device='meta')
         >>> mod = torch.nn.Conv2d(3, 128, 3)
diff --git a/colossalai/fx/profiler/experimental/memory.py b/colossalai/fx/profiler/experimental/shard_utils.py
similarity index 100%
rename from colossalai/fx/profiler/experimental/memory.py
rename to colossalai/fx/profiler/experimental/shard_utils.py
diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py
new file mode 100644
index 000000000..5064283b7
--- /dev/null
+++ b/colossalai/fx/profiler/memory_utils.py
@@ -0,0 +1,71 @@
+from typing import Dict, List, Tuple, Union
+
+import torch
+from torch.fx import GraphModule, Node
+
+from .._compatibility import compatibility, is_compatible_with_meta
+
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+
+
+@compatibility(is_backward_compatible=True)
+def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
+    """Calculate activation size of a node.
+
+    Args:
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
+
+    Returns:
+        int: The activation size
+    """
+    act_size = 0
+    if isinstance(out, torch.Tensor):
+        if out.is_quantized:
+            act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
+        else:
+            act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
+    elif isinstance(out, dict):
+        value_list = [v for _, v in out.items()]
+        act_size += activation_size(value_list)
+    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
+        for element in out:
+            act_size += activation_size(element)
+    return act_size
+
+
+@compatibility(is_backward_compatible=True)
+def parameter_size(mod: torch.nn.Module) -> int:
+    """Calculate parameter size of a node.
+
+    Args:
+        mod (torch.nn.Module): The target `torch.nn.Module`
+
+    Returns:
+        int: The parameter size
+    """
+    param_size = 0
+    for param in mod.parameters():
+        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
+    return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if is_compatible_with_meta():
+            from .constants import ALIAS_ATEN
+            if n.target in ALIAS_ATEN:
+                inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+
+    return inplace
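The helpers moved into `memory_utils.py` keep their pre-rename behavior, so a quick doctest-style sanity check reads as follows (illustrative only, not part of the patch; the byte counts assume default float32 tensors, i.e. 4 bytes per element):

    >>> import torch
    >>> from colossalai.fx.profiler import activation_size, parameter_size
    >>> activation_size(torch.rand(4, 3))          # 4 * 3 elements * 4 bytes
    48
    >>> parameter_size(torch.nn.Linear(16, 16))    # (16 * 16 + 16) params * 4 bytes
    1088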
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 2fa5c41c0..fbffb23d2 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map

 from .._compatibility import compatibility
 from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
 from .opcount import flop_mapping
 from .tensor import MetaTensor
@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 @compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
-    Wrap a `call_function` node or `torch.nn.functional` in order to
+    Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
-
+
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn.functional` are available.
-
+
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
@@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
 def profile_method(target: 'Target', device: str = 'meta') -> Callable:
     """
     Wrap a `call_method` node
-    record the memory cost and FLOPs of the execution.
+    record the memory cost and FLOPs of the execution.
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
 @compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
-    Wrap a `call_module` node or `torch.nn` in order to
+    Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
-
+
     Warnings:
         You may only use tensors with `device=meta` for this wrapped function.
         Only original `torch.nn` are available.
-
+
     Example:
         >>> input = torch.rand(4, 3, 224, 224, device='meta')
         >>> mod = torch.nn.Conv2d(3, 128, 3)
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/shard_utils.py
similarity index 58%
rename from colossalai/fx/profiler/memory.py
rename to colossalai/fx/profiler/shard_utils.py
index 2e8b5d51b..3ba0cb68e 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/shard_utils.py
@@ -1,58 +1,18 @@
-from typing import Dict, List, Tuple, Union
-
 import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node

 from .._compatibility import compatibility, is_compatible_with_meta
+from .memory_utils import activation_size

 if is_compatible_with_meta():
     from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS

-__all__ = [
-    'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
-]
-
-
-@compatibility(is_backward_compatible=True)
-def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
-    """Calculate activation size of a node.
-
-    Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
-
-    Returns:
-        int: The activation size
-    """
-    act_size = 0
-    if isinstance(out, torch.Tensor):
-        act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
-    elif isinstance(out, dict):
-        value_list = [v for _, v in out.items()]
-        act_size += activation_size(value_list)
-    elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
-        for element in out:
-            act_size += activation_size(element)
-    return act_size
-
-
-@compatibility(is_backward_compatible=True)
-def parameter_size(mod: torch.nn.Module) -> int:
-    """Calculate parameter size of a node.
-
-    Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
-
-    Returns:
-        int: The parameter size
-    """
-    param_size = 0
-    for param in mod.parameters():
-        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-    return param_size
+__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


+@compatibility(is_backward_compatible=False)
 def calculate_fwd_in(n: Node) -> int:
-    """A helper function to calculate `fwd_in`
+    """A helper function to calculate `fwd_in` (with sharding spec)

     Args:
         n (Node): a node from the graph
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
     Returns:
         fwd_in (int): the result of `fwd_in`
     """
+    # TODO(super-dainiu): should divide the memory by sharding spec
     return activation_size(n.meta["fwd_in"])


+@compatibility(is_backward_compatible=False)
 def calculate_fwd_tmp(n: Node) -> int:
-    """A helper function to calculate `fwd_tmp`
+    """A helper function to calculate `fwd_tmp` (with sharding spec)
     Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.

     Args:
@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """

+    # TODO(super-dainiu): should divide the memory by sharding spec
    def is_relu_like_node(n: Node) -> bool:
        """Check if a node is a ReLU-like node.
        ReLU-like nodes have the following properties:
@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
     return 0


+@compatibility(is_backward_compatible=False)
 def calculate_fwd_out(n: Node) -> int:
-    """A helper function to calculate `fwd_out`
+    """A helper function to calculate `fwd_out` (with sharding spec)

     Args:
         n (Node): a node from the graph
@@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
         fwd_out (int): the result of `fwd_out`
     """

+    # TODO(super-dainiu): should divide the memory by sharding spec
     def intersect(a, b):
         return {k: a[k] for k in a if k in b}

@@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
     return activation_size(intersect(fwd_in, fwd_out))


-def is_inplace(n: Node):
-    """Get the inplace argument from torch.fx.Node
-
+def calculate_fwd_time(n: Node) -> float:
+    """A helper function to calculate `fwd_time` (with sharding spec)
     Args:
-        node (Node): torch.fx.Node
-
+        n (Node): a node from the graph
     Returns:
-        bool: indicates whether this op is inplace
+        fwd_time (float): the result of `fwd_time`
     """
-    inplace = False
-    if n.op == "call_function":
-        inplace = n.kwargs.get("inplace", False)
-        if is_compatible_with_meta():
-            from .constants import ALIAS_ATEN
-            if n.target in ALIAS_ATEN:
-                inplace = True
-    elif n.op == "call_module":
-        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["fwd_flop"]

-    return inplace
+
+def calculate_bwd_time(n: Node) -> float:
+    """A helper function to calculate `bwd_time` (with sharding spec)
+    Args:
+        n (Node): a node from the graph
+    Returns:
+        bwd_time (float): the result of `bwd_time`
+    """
+    # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
+    return n.meta["bwd_flop"]
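The `calculate_fwd_in/out/tmp` helpers read the `fwd_in`, `fwd_out` and `fwd_tmp` entries that MetaInfoProp records in `node.meta`, so they are meant to be called after a propagation pass. A minimal sketch of that flow, modeled on the test file below (the exact MetaTensor invocation may differ from release to release):

    >>> import torch
    >>> from torch.fx import symbolic_trace
    >>> from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    >>> from colossalai.fx.profiler import MetaTensor, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
    >>> gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()))
    >>> MetaInfoProp(gm).run(MetaTensor(torch.rand(8, 16, device='meta')))
    >>> for n in gm.graph.nodes:
    ...     print(n.name, calculate_fwd_in(n), calculate_fwd_out(n), calculate_fwd_tmp(n))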
diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py
index a7f7c8159..1c5abb81d 100644
--- a/colossalai/fx/tracer/_meta_trace.py
+++ b/colossalai/fx/tracer/_meta_trace.py
@@ -1,7 +1,5 @@
-from colossalai.fx.profiler.memory import activation_size
 import torch
-from torch.fx import Node, Graph
-from torch.fx.graph import _Namespace
+from torch.fx import Graph, Node
 from torch.utils._pytree import tree_map


diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py
index a9921af3c..c71796018 100644
--- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py
+++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py
@@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.fx
 import torchvision.models as tm
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from gpt_utils import gpt2_medium, gpt2_xl
 from torch.fx import symbolic_trace

+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
 if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
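As the TODOs in shard_utils.py note, `calculate_fwd_time` and `calculate_bwd_time` currently return raw FLOP counts rather than seconds; turning them into wall-clock estimates means dividing by device throughput and the number of GPUs. A hypothetical conversion sketch (the helper name and the peak-throughput constant are illustrative assumptions, not part of the patch):

    >>> from colossalai.fx.profiler import calculate_fwd_time, calculate_bwd_time
    >>> PEAK_FLOPS = 156e12    # assumed per-GPU peak, e.g. A100 TF32 dense throughput
    >>> def estimated_step_time(n, n_gpus=1):    # hypothetical helper
    ...     return (calculate_fwd_time(n) + calculate_bwd_time(n)) / (PEAK_FLOPS * n_gpus)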