mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fx] add rules to linearize computation graphs for searching. (#1461)
* [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies.
This commit is contained in:
@@ -1 +1 @@
|
||||
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
|
||||
from .ckpt_solver_chen import chen_greedy
|
||||
|
@@ -1,16 +1,33 @@
|
||||
from typing import List, Set, Tuple
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx import GraphModule, Node
|
||||
import math
|
||||
|
||||
__all__ = ['chen_greedy', 'chen_sqrtn']
|
||||
__all__ = ['chen_greedy']
|
||||
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||
|
||||
|
||||
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
|
||||
"""
|
||||
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
|
||||
"""
|
||||
|
||||
def is_sink():
|
||||
"""
|
||||
If we can free all memories when executing a certain node, it is a sink.
|
||||
"""
|
||||
return not sum((v for k, v in deps.items()))
|
||||
|
||||
deps = {}
|
||||
ckpt_nodes = []
|
||||
for n in gm.graph.nodes:
|
||||
if n.op == 'call_module':
|
||||
for n_par in n._input_nodes:
|
||||
deps[n_par] -= 1 # free memory and dependencies
|
||||
|
||||
# We can only put act_ckpt on these nodes
|
||||
if n.op in CKPT_OP and is_sink():
|
||||
ckpt_nodes.append(n)
|
||||
deps[n] = len(n.users) # add dependencies for future executions
|
||||
return ckpt_nodes
|
||||
|
||||
|
||||
@@ -71,32 +88,7 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||
for i, seg in enumerate(ckpt):
|
||||
for idx in range(*seg):
|
||||
n = node_list[idx]
|
||||
if n.op in ['call_module', 'call_method', 'call_function']:
|
||||
setattr(n, 'activation_checkpoint', str(i))
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def chen_sqrtn(gm: GraphModule) -> GraphModule:
    """
    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

    The node list is cut into approximately sqrt(n) equal segments and every
    segment except the last is tagged for activation checkpointing.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_sqrtn(gm)

    Args:
        gm (GraphModule): The module to add checkpoints

    Returns:
        GraphModule: the same module, with ``activation_checkpoint`` set on the
        nodes of every checkpointed segment.
    """
    gm.graph.lint()    # make sure nodes are in topological order
    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
    for idx, n in enumerate(gm.graph.nodes):
        # We should not add act_ckpt to the placeholder.
        # The last segment ((idx + 1) // k == k) should not be checkpointed.
        if n.op != 'placeholder' and (idx + 1) // k < k:
            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
    gm.recompile()
    return gm
|
||||
|
Reference in New Issue
Block a user