[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix import error.
This commit is contained in:
Super Daniel
2022-09-07 11:21:04 +08:00
committed by GitHub
parent 7d49e7b2db
commit 4f59693207
38 changed files with 776 additions and 263 deletions

View File

@@ -1,13 +1,10 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
@compatibility(is_backward_compatible=True)
@@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add an additional check for initial args to ensure all tensors appear with `device='meta'`
"""
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
@@ -93,8 +82,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
result, profile = super().run_node(n)
profile: MetaProfile
result, flop_count, mem_stat = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
@@ -106,12 +94,17 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = meta
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', profile.param + profile.activation)
setattr(n, '__param__', profile.param)
setattr(n, '__activation__', profile.activation)
setattr(n, '__flops__', profile.flops)
setattr(n, '__macs__', profile.macs)
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
n.meta['type'] = type(result)
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@@ -132,11 +125,12 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
return result, (0, 0), (0, activation_size(result), 0, 0)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -153,10 +147,10 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -172,7 +166,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@@ -191,7 +186,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return profile_method(target)(*args, **kwargs)
@@ -209,7 +205,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
@@ -231,9 +228,11 @@ class MetaInfoProp(torch.fx.Interpreter):
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
Any: The return value referenced by the output node
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return args[0], MetaProfile(0, 0, 0, 0)
return args[0], (0, 0), (0, 0, 0, 0)
def propagate(self, *args):
"""