diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py
index 6d0475f70..b1850798c 100644
--- a/colossalai/fx/__init__.py
+++ b/colossalai/fx/__init__.py
@@ -1,2 +1,3 @@
 from .tracer import ColoTracer, meta_trace
 from .graph_module import ColoGraphModule
+from .passes import MetaInfoProp
diff --git a/colossalai/fx/passes/__init__.py b/colossalai/fx/passes/__init__.py
index b1e95b876..aa6a7009c 100644
--- a/colossalai/fx/passes/__init__.py
+++ b/colossalai/fx/passes/__init__.py
@@ -1,2 +1,3 @@
 from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
\ No newline at end of file
+from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
+from .meta_info_prop import MetaInfoProp
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 170176d71..403819a29 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -4,7 +4,7 @@ import torch
 import torch.fx
 from torch.fx.node import Node, Argument, Target
 from torch.utils._pytree import tree_map
-from typing import Any, Tuple, NamedTuple, Dict
+from typing import Any, List, Tuple, NamedTuple, Dict
 from torch.fx._compatibility import compatibility
 from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
 
@@ -48,28 +48,33 @@ class MetaInfoProp(torch.fx.Interpreter):
     Usage:
         BATCH_SIZE = 2
         DIM_IN = 4
+        DIM_HIDDEN = 16
         DIM_OUT = 16
-        model = torch.nn.Linear(DIM_IN, DIM_OUT)
+        model = torch.nn.Sequential(
+            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
+            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
+        )
         input_sample = torch.rand(BATCH_SIZE, DIM_IN)
-        orig_output = model(input_sample)
         gm = symbolic_trace(model)
 
-        MetaInfoProp(gm).run(input_sample)
-
-        for node in gm.graph.nodes:
-            print(node.name, node.meta['tensor_meta'].dtype,
-                node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)
+        interp = MetaInfoProp(gm)
+        interp.run(input_sample)
+        print(interp.summary(unit='kb'))    # don't panic if some statistics are 0.00 KB
+
         # output of above code is
-        # input_1 torch.float32 torch.Size([2, 4]) 8
-        # weight torch.float32 torch.Size([16, 4]) 64
-        # bias torch.float32 torch.Size([16]) 16
-        # linear torch.float32 torch.Size([2, 16]) 32
-        # output torch.float32 torch.Size([2, 16]) 32
+            Op type       Op    Forward FLOPs    Backward FLOPs    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
+        -----------  -------  ---------------  ----------------  -------------  ---------  ---------  ---------  ---------
+        placeholder  input_1          0 FLOPs           0 FLOPs          False    0.00 KB    0.00 KB    0.00 KB    0.00 KB
+        call_module       _0        128 FLOPs         288 FLOPs           True    0.12 KB    0.00 KB    0.34 KB    0.00 KB
+        call_module       _1        512 FLOPs       1,056 FLOPs           True    0.12 KB    0.00 KB    1.19 KB    0.00 KB
+             output   output          0 FLOPs           0 FLOPs           True    0.00 KB    0.00 KB    0.00 KB    0.00 KB
 
     Args:
         module (GraphModule): The module to be executed
     """
 
+    _is_proped: bool = False
+
     @compatibility(is_backward_compatible=True)
     def run_node(self, n: Node) -> Any:
         """
@@ -84,6 +89,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         Returns:
             Any: The result of executing ``n``
         """
+        self._is_proped = True
         result, meta_info = super().run_node(n)
 
         def extract_tensor_meta(obj):
@@ -236,3 +242,68 @@ class MetaInfoProp(torch.fx.Interpreter):
             Any: The value returned from executing the Module
         """
         return super().run(*args)
+
+    def summary(self, unit: str = 'MB') -> str:
+        """
+        Summarizes the memory and FLOPs statistics of the `GraphModule` in
+        tabular format. Note that this API requires the ``tabulate`` module
+        to be installed.
+
+        Args:
+            unit (str): the unit for memory statistics, one of 'KB', 'MB', 'GB' or 'TB'. Defaults to 'MB'.
+        """
+        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
+        try:
+            from tabulate import tabulate
+        except ImportError:
+            print("`summary` relies on the library `tabulate`, "
+                  "which could not be found on this machine. Run `pip "
+                  "install tabulate` to install the library.")
+            raise
+
+        assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
+
+        # Build up a list of summary information for each node
+        node_summaries: List[List[Any]] = []
+
+        def mem_repr(mem: int) -> str:
+            unit_divisor_map = {
+                'kb': 1024,
+                'mb': 1024**2,
+                'gb': 1024**3,
+                'tb': 1024**4,
+            }
+            return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
+
+        def flops_repr(flop: int) -> str:
+            return f"{flop:,} FLOPs"
+
+        for node in self.module.graph.nodes:
+            node: Node
+            node_summaries.append([
+                node.op,
+                str(node),
+                flops_repr(node.meta['fwd_flop']),
+                flops_repr(node.meta['bwd_flop']),
+                node.meta['save_fwd_in'],
+                mem_repr(node.meta['fwd_mem_out']),
+                mem_repr(node.meta['fwd_mem_tmp']),
+                mem_repr(node.meta['bwd_mem_out']),
+                mem_repr(node.meta['bwd_mem_tmp']),
+            ])
+
+        # Use the ``tabulate`` library to create a well-formatted table
+        # presenting our summary information
+        headers: List[str] = [
+            'Op type',
+            'Op',
+            'Forward FLOPs',
+            'Backward FLOPs',
+            'SAVE_FWD_IN',
+            'FWD_OUT',
+            'FWD_TMP',
+            'BWD_OUT',
+            'BWD_TMP',
+        ]
+
+        return tabulate(node_summaries, headers=headers, stralign='right')
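
For reviewers who want to exercise the new API, here is the docstring's usage example written out as a runnable script. This is a minimal sketch: it assumes `symbolic_trace` refers to `torch.fx.symbolic_trace` (the docstring does not say which tracer it means) and that `tabulate` is installed.

```python
# Minimal sketch of the new MetaInfoProp.summary() API, mirroring the
# docstring added in this patch. Assumptions: torch.fx.symbolic_trace is
# sufficient for tracing this toy model, and `tabulate` is available.
import torch
from torch.fx import symbolic_trace

from colossalai.fx import MetaInfoProp  # re-exported by this patch

BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16

model = torch.nn.Sequential(
    torch.nn.Linear(DIM_IN, DIM_HIDDEN),
    torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)

gm = symbolic_trace(model)
interp = MetaInfoProp(gm)
interp.run(input_sample)          # must run before summary(), or its assert fires
print(interp.summary(unit='kb'))  # per-node FLOPs and memory statistics
```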