Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[hotfix/rotor] fix variable names (#1597)
* [fx] add some comments and docstrings.
* [fx] add dataflow analysis for an autograd graph.
* add interpretation for graph analysis.
* [fx] before doing save_tensor_hooks.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] a very accurate version on GPT-2.
* [fx] refactor code.
* [fx] remove redundant inplace=True.
* [fx] refactor code.
* [fx] refactor code.
* [fx] refactor code.
* [fx] dive into backward memory.
* [fx] fix variable names in ckpt_solvers and unskip tests.
* [fx] commit my changes.
* [fx] restore skips.
* [fx] restore skips.
* [fx] change stage into phase.
* [fx] change stage into phase.
* [fx] change stage into phase.
@@ -5,8 +5,8 @@ import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
-from .dataflow import autograd_graph_analysis, Stage
-from .memory import WEIRD_OPS
+from .dataflow import GraphInfo, autograd_graph_analysis, Phase
+from .memory import WEIRD_OPS, activation_size
 from .tensor import MetaTensor
 from .opcount import flop_mapping
 
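The import change mirrors the rename this commit performs: the `Stage` enum from `.dataflow` becomes `Phase`, and the profiler now also pulls in `GraphInfo` and `activation_size`. As a rough, hypothetical illustration of what those names stand for (the exact definitions in ColossalAI's `.dataflow` module may differ), a phase enum and a stripped-down graph-info record could look like this:

from dataclasses import dataclass
from enum import Enum, auto

class Phase(Enum):
    # the four phase labels the diff below tags nodes and flop counts with
    PLACEHOLDER = auto()
    FORWARD = auto()
    LOSS = auto()
    BACKWARD = auto()

@dataclass
class GraphInfo:
    # only the two fields this diff assigns directly; the real class in
    # .dataflow carries additional memory-related fields
    fwd_flop: int = 0
    bwd_flop: int = 0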
@@ -41,14 +41,11 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
 
     # `flop_count`` serves as a global dictionary to store results.
     flop_count = {
-        Stage.FORWARD: 0,
-        Stage.LOSS: 0,
-        Stage.BACKWARD: 0,
+        Phase.FORWARD: 0,
+        Phase.LOSS: 0,
+        Phase.BACKWARD: 0,
     }
 
-    # `stage` will mark the stage of autograd from outside scope.
-    stage = Stage.FORWARD
-
     # FlopTensor not only get the flop statistics of a single node,
     # it also build a full autograd graph for this node.
     # This makes sure we can analyze the dependencies of memory, and
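The comment above promises that the profiler does not just count flops but also "build[s] a full autograd graph for this node". A small sketch of what that means: as ops run, matching nodes are appended to a side `torch.fx` graph so the dependencies between activations can be analyzed afterwards. Here the ops are recorded by hand instead of from a real dispatch hook:

import torch
from torch.fx import Graph

subgraph = Graph()
x = subgraph.placeholder('x')
mm = subgraph.create_node('call_function', torch.mm, (x, x), {})
act = subgraph.create_node('call_function', torch.relu, (mm,), {})
subgraph.output(act)

print(subgraph)          # x -> mm -> relu -> output, i.e. the recorded dataflow
print(list(act.args))    # the dependency edges a memory analysis would walk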
@@ -85,9 +82,9 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
 
             # run aten for backend=CPU but actually on backend=Meta
             out = func(*args, **kwargs)
-            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
+            flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
             node.meta['out'] = normalize_tuple(out)
-            node.meta['stage'] = stage
+            node.meta['phase'] = phase
 
             def wrap(x):
                 return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
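The changed line charges each dispatched aten call to the current phase via `flop_mapping[func]`, which is evaluated on meta tensors so no real compute or memory traffic happens. A hypothetical two-entry miniature of such a mapping (the real table lives in `.opcount` and covers far more ops):

import torch

aten = torch.ops.aten

# each entry maps an aten overload to a function of (inputs, outputs)
# returning a flop count; names and formulas here are illustrative only
flop_mapping = {
    aten.mm.default: lambda inputs, outputs: 2 * inputs[0].shape[0] * inputs[0].shape[1] * inputs[1].shape[1],
    aten.relu.default: lambda inputs, outputs: outputs[0].numel(),
}

# meta tensors carry only shapes and dtypes, which is all the formulas need
a = torch.empty(8, 16, device='meta')
b = torch.empty(16, 4, device='meta')
out = torch.mm(a, b)

print(flop_mapping[aten.mm.default]((a, b), (out,)))   # 2 * 8 * 16 * 4 = 1024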
@@ -121,7 +118,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
             x._node = subgraph.create_node('placeholder',
                                            'placeholder', (subgraph._root,),
                                            name=subgraph._graph_namespace.create_name('input', x._tensor))
-            x._node.meta['stage'] = Stage.PLACEHOLDER
+            x._node.meta['phase'] = Phase.PLACEHOLDER
             x._node.meta['out'] = (x._tensor,)
 
     tree_map(set_placeholder, args)
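Only the metadata key changes here ('stage' becomes 'phase'), but note that `node.meta` is just the free-form dict `torch.fx` attaches to every node, so the tagging can be reproduced with the public API alone (the `Phase` enum is again the hypothetical one sketched above):

import torch
from torch.fx import Graph
from enum import Enum, auto

class Phase(Enum):
    PLACEHOLDER = auto()

g = Graph()
inp = g.placeholder('input_0')                         # public counterpart of create_node('placeholder', ...)
inp.meta['phase'] = Phase.PLACEHOLDER                  # node.meta is a free-form dict on every fx Node
inp.meta['out'] = (torch.empty(4, 4, device='meta'),)  # shape-only record of what flows out of the node

print(inp.meta['phase'], inp.meta['out'][0].shape)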
@@ -135,6 +132,8 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
     def unpack(x):
         return x
 
+    # `phase` will mark the phase of autograd from outside scope.
+    phase = Phase.FORWARD
     # mark saved tensors with saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
         if isinstance(target, str):
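`phase` is now initialized right before the forward pass, immediately ahead of the `saved_tensors_hooks` context, instead of at the top of `_profile` as `stage` was. The hooks themselves are standard PyTorch: `pack` is called for every tensor autograd decides to keep for backward, which is how the profiler marks which activations survive the forward pass. A self-contained example of that mechanism:

import torch

saved_shapes = []

def pack(x):
    # called once per tensor autograd saves for the backward pass
    saved_shapes.append(tuple(x.shape))
    return x

def unpack(x):
    return x

a = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = (a * a).sum()

out.backward()
print(saved_shapes)   # shapes of the tensors the multiplication kept for its backward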
@@ -147,13 +146,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
         # If the output is not a floating point `torch.Tensor` or it does not
         # requires grad, then we should not run backward for this node.
         if is_autogradable(out) and out.requires_grad:
-            stage = Stage.LOSS
+            phase = Phase.LOSS
             loss = out.sum()
-            stage = Stage.BACKWARD
+            phase = Phase.BACKWARD
             loss.backward()
 
     graph_info = autograd_graph_analysis(subgraph)
-    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Stage.FORWARD], flop_count[Stage.BACKWARD]
+    graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
 
     def unwrap(x):
         return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
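Switching `phase` to `Phase.LOSS` around `out.sum()` and to `Phase.BACKWARD` before `loss.backward()` is what splits the counters, because the dispatch hook reads `phase` from the enclosing scope at call time. ColossalAI does the interception through its `FlopTensor` subclass; the sketch below swaps in a `TorchDispatchMode` (a different but real PyTorch mechanism) purely to show the same phase-switching pattern, counting ops instead of flops:

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from enum import Enum, auto

class Phase(Enum):
    FORWARD = auto()
    LOSS = auto()
    BACKWARD = auto()

phase = Phase.FORWARD
per_phase = {p: 0 for p in Phase}

class PhaseTally(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        per_phase[phase] += 1                  # reads the current `phase` at call time
        return func(*args, **(kwargs or {}))

a = torch.randn(4, 4, requires_grad=True)
with PhaseTally():
    out = a @ a
    phase = Phase.LOSS
    loss = out.sum()
    phase = Phase.BACKWARD
    loss.backward()

print(per_phase)   # ops seen while each phase label was active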
@@ -180,6 +179,11 @@ def profile_function(target: 'Target') -> Callable:
 
         # If there is an argument that this `call_function` is inplace, we should
         # skip the autograd profiling.
+        if kwargs.get('inplace', False):
+            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
+            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
 
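For an inplace `call_function`, autograd profiling is skipped entirely: the inputs are moved to the meta device, the op runs once, and the returned `GraphInfo` is filled from `activation_size`; the analogous `call_module` hunk follows below. A rough sketch of that shortcut, with a hypothetical reimplementation of `activation_size` (the real helper lives in `.memory`):

import torch
from torch.utils._pytree import tree_flatten, tree_map

def activation_size(obj) -> int:
    # hypothetical stand-in for .memory's activation_size: total bytes of
    # every tensor found anywhere in a nested args/kwargs structure
    leaves, _ = tree_flatten(obj)
    return sum(t.numel() * t.element_size() for t in leaves if isinstance(t, torch.Tensor))

args, kwargs = (torch.randn(4, 4),), {}

# the inplace shortcut: move tensors to meta, run the op once, report sizes only
meta_args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
out = torch.relu(meta_args[0])

print(activation_size((args, kwargs)), activation_size(out))   # 64 bytes in, 64 bytes out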
@@ -222,6 +226,11 @@ def profile_module(module: torch.nn.Module) -> Callable:
 
         # If there is an argument that this `call_module` is inplace, we should
         # skip the autograd profiling.
+        if getattr(module, 'inplace', False):
+            args = tree_map(lambda x: x.to('meta'), args)
+            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
+            out = func(*args, **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
         out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
         return out, meta
 