mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[fx] provide a stable but not accurate enough version of profiler. (#1547)
* [fx] compute memory stat and flop count for MetaInfoProp.
* [fx] modify node attribute.
* [fx] modify ckpt_chen.
* [fx] fix compatibility.
* [fx] fix import error.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip if torch 1.11.0.
* [fx] recover MetaInfoProp support for PyTorch 1.11.
* [fx] provide a stable but not accurate enough version of profiler.
* [fx] provide a stable but not accurate enough version of profiler.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix import error.
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
from operator import add, getitem
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
||||
from functools import reduce
|
||||
from typing import Any, Tuple, NamedTuple, Dict
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
|
||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
"""
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
|
||||
"""
|
||||
Add an additional check on the initial args to ensure that every tensor is on `device='meta'`.
|
||||
"""
|
||||
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
|
||||
return super().run(*args, initial_env, enable_io_processing)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run_node(self, n: Node) -> Any:
|
||||
"""
|
||||
@@ -93,8 +82,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
Returns:
|
||||
Any: The result of executing ``n``
|
||||
"""
|
||||
result, profile = super().run_node(n)
|
||||
profile: MetaProfile
|
||||
result, flop_count, mem_stat = super().run_node(n)
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
@@ -106,12 +94,17 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
n.meta['tensor_meta'] = meta
|
||||
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', profile.param + profile.activation)
|
||||
setattr(n, '__param__', profile.param)
|
||||
setattr(n, '__activation__', profile.activation)
|
||||
setattr(n, '__flops__', profile.flops)
|
||||
setattr(n, '__macs__', profile.macs)
|
||||
setattr(n, 'node_size', mem_stat[1])
|
||||
setattr(n, 'fwd_flop', flop_count[0])
|
||||
setattr(n, 'bwd_flop', flop_count[1])
|
||||
setattr(n, 'fwd_tmp', mem_stat[0])
|
||||
setattr(n, 'fwd_out', mem_stat[1])
|
||||
setattr(n, 'bwd_tmp', mem_stat[2])
|
||||
setattr(n, 'bwd_out', mem_stat[3])
|
||||
n.meta['type'] = type(result)
|
||||
|
||||
for param in self.module.parameters():
|
||||
param.grad = None
|
||||
return result
|
||||
|
||||
# Main Node running APIs
|
||||
@@ -132,11 +125,12 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Returns:
|
||||
result (Any): The argument value that was retrieved
|
||||
profile (MetaProfile): The meta profile of this node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
result = super().placeholder(target, args, kwargs)
|
||||
# A placeholder node only has activation
|
||||
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
|
||||
return result, (0, 0), (0, activation_size(result), 0, 0)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
@@ -153,10 +147,10 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
profile (MetaProfile): The meta profile of this node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
# A get_attr node never has parameters, activations, FLOPs, or MACs
|
||||
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
|
||||
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
@@ -172,7 +166,8 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
profile (MetaProfile): The meta profile of this node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
assert not isinstance(target, str)
|
||||
return profile_function(target)(*args, **kwargs)
|
||||
@@ -191,7 +186,8 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
profile (MetaProfile): The meta profile of this node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
return profile_method(target)(*args, **kwargs)
|
||||
|
||||
@@ -209,7 +205,8 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
profile (MetaProfile): The meta profile of this node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
# Retrieve executed args and kwargs values from the environment
|
||||
# Execute the method and return the result
|
||||
@@ -231,9 +228,11 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return:
|
||||
Any: The return value referenced by the output node
|
||||
result (Any): The return value referenced by the output node
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
"""
|
||||
return args[0], MetaProfile(0, 0, 0, 0)
|
||||
return args[0], (0, 0), (0, 0, 0, 0)
|
||||
|
||||
def propagate(self, *args):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user