[fx] provide an accurate estimation of memory. (#1587)

* [fx] add some comment and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add intepretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.
This commit is contained in:
Super Daniel
2022-09-14 09:36:43 +08:00
committed by GitHub
parent 27fe8af60c
commit 5c494d4540
6 changed files with 301 additions and 95 deletions

View File

@@ -1,10 +1,12 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
@compatibility(is_backward_compatible=True)
@@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node with meta tensor and
record the shape, FLOPs, MACs and type of the result
record the memory usage, FLOPs, and type of the result
into the corresponding node.
Usage:
@@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
result, flop_count, mem_stat = super().run_node(n)
result, meta_info = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
@@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
else:
return TensorMetadata(None, None, False, None, 0, False)
meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = meta
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
n.meta['type'] = type(result)
# retain the autograd graph
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, (0, 0), (0, activation_size(result), 0, 0)
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return profile_method(target)(*args, **kwargs)
@@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
@@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], (0, 0), (0, 0, 0, 0)
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
def propagate(self, *args):
"""