[hotfix/rotor] fix variable names (#1597)

* [fx] add some comment and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add intepretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.

* [fx] fix variable names in ckpt_solvers and unskip tests.

* [fx] commit my changes.

* [fx] restore skips.

* [fx] restore skips.

* [fx] chaange stage into phase.

* [fx] chaange stage into phase.

* [fx] chaange stage into phase.
This commit is contained in:
Super Daniel
2022-09-14 14:27:04 +08:00
committed by GitHub
parent faa23b9d9a
commit c8e9b2ad78
7 changed files with 85 additions and 83 deletions

View File

@@ -5,8 +5,8 @@ import torch
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
from .dataflow import autograd_graph_analysis, Stage
from .memory import WEIRD_OPS
from .dataflow import GraphInfo, autograd_graph_analysis, Phase
from .memory import WEIRD_OPS, activation_size
from .tensor import MetaTensor
from .opcount import flop_mapping
@@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# `flop_count`` serves as a global dictionary to store results.
flop_count = {
Stage.FORWARD: 0,
Stage.LOSS: 0,
Stage.BACKWARD: 0,
Phase.FORWARD: 0,
Phase.LOSS: 0,
Phase.BACKWARD: 0,
}
# `stage` will mark the stage of autograd from outside scope.
stage = Stage.FORWARD
# FlopTensor not only get the flop statistics of a single node,
# it also build a full autograd graph for this node.
# This makes sure we can analyze the dependencies of memory, and
@@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
node.meta['out'] = normalize_tuple(out)
node.meta['stage'] = stage
node.meta['phase'] = phase
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
@@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
x._node = subgraph.create_node('placeholder',
'placeholder', (subgraph._root,),
name=subgraph._graph_namespace.create_name('input', x._tensor))
x._node.meta['stage'] = Stage.PLACEHOLDER
x._node.meta['phase'] = Phase.PLACEHOLDER
x._node.meta['out'] = (x._tensor,)
tree_map(set_placeholder, args)
@@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
def unpack(x):
return x
# `phase` will mark the phase of autograd from outside scope.
phase = Phase.FORWARD
# mark saved tensors with saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
if isinstance(target, str):
@@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if is_autogradable(out) and out.requires_grad:
stage = Stage.LOSS
phase = Phase.LOSS
loss = out.sum()
stage = Stage.BACKWARD
phase = Phase.BACKWARD
loss.backward()
graph_info = autograd_graph_analysis(subgraph)
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
@@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
# If there is an argument that this `call_function` is inplace, we should
# skip the autograd profiling.
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
out, meta = _profile(func, *args, **kwargs)
return out, meta
@@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
# If there is an argument that this `call_module` is inplace, we should
# skip the autograd profiling.
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
return out, meta