mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[fx] provide an accurate estimation of memory. (#1587)
* [fx] add some comment and docstrings. * [fx] add dataflow analysis for an autograd graph. * add intepretation for graph analysis. * [fx] before doing save_tensor_hooks. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] provide an accurate estimation of memory except for GPT-2. * [fx] a very accurate version on GPT-2. * [fx] refactor code. * [fx] remove redundant inplace=True. * [fx] refactor code. * [fx] refactor code. * [fx] refactor code. * [fx] dive into backward memory.
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from dataclasses import asdict
|
||||
from colossalai.fx.profiler import GraphInfo
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Any, Tuple, NamedTuple, Dict
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
|
||||
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@@ -40,7 +42,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
||||
class MetaInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
Execute an FX graph Node-by-Node with meta tensor and
|
||||
record the shape, FLOPs, MACs and type of the result
|
||||
record the memory usage, FLOPs, and type of the result
|
||||
into the corresponding node.
|
||||
|
||||
Usage:
|
||||
@@ -82,7 +84,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
Returns:
|
||||
Any: The result of executing ``n``
|
||||
"""
|
||||
result, flop_count, mem_stat = super().run_node(n)
|
||||
result, meta_info = super().run_node(n)
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
@@ -90,21 +92,20 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
else:
|
||||
return TensorMetadata(None, None, False, None, 0, False)
|
||||
|
||||
meta = tree_map(extract_tensor_meta, result)
|
||||
n.meta['tensor_meta'] = meta
|
||||
tensor_meta = tree_map(extract_tensor_meta, result)
|
||||
n.meta['tensor_meta'] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', mem_stat[1])
|
||||
setattr(n, 'fwd_flop', flop_count[0])
|
||||
setattr(n, 'bwd_flop', flop_count[1])
|
||||
setattr(n, 'fwd_tmp', mem_stat[0])
|
||||
setattr(n, 'fwd_out', mem_stat[1])
|
||||
setattr(n, 'bwd_tmp', mem_stat[2])
|
||||
setattr(n, 'bwd_out', mem_stat[3])
|
||||
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
|
||||
for par in n.all_input_nodes:
|
||||
par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
|
||||
n.meta['type'] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
param.grad = None
|
||||
|
||||
return result
|
||||
|
||||
# Main Node running APIs
|
||||
@@ -125,12 +126,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Returns:
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
result = super().placeholder(target, args, kwargs)
|
||||
# A placeholder node only has activation
|
||||
return result, (0, 0), (0, activation_size(result), 0, 0)
|
||||
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
@@ -147,10 +145,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
|
||||
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
@@ -166,8 +163,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
assert not isinstance(target, str)
|
||||
return profile_function(target)(*args, **kwargs)
|
||||
@@ -186,8 +182,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
return profile_method(target)(*args, **kwargs)
|
||||
|
||||
@@ -205,8 +200,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
# Retrieve executed args and kwargs values from the environment
|
||||
# Execute the method and return the result
|
||||
@@ -229,10 +223,9 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
Return:
|
||||
result (Any): The argument value that was retrieved
|
||||
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
|
||||
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
return args[0], (0, 0), (0, 0, 0, 0)
|
||||
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
|
||||
|
||||
def propagate(self, *args):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user