Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-06 10:34:23 +00:00
[fx] Add nested checkpoint in activation checkpoint codegen (#1585)
* [fx] add nested activation_checkpoint codegen
* undo algorithms commits
* solver
* undo some commits
* [fx] torch11 add nested activation checkpoint codegen
* remove some imports
* [fx] add some comments in activation codegen
* [fx] codegen instance error fix
This commit is contained in:
parent
1c9ec32734
commit
f3687e4ee2
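
This commit generalizes the `activation_checkpoint` annotation from a single int label to a list of labels, one per nesting level. As a quick illustration (a minimal sketch based on the test added in this commit, with `SimpleNamespace` standing in for a `torch.fx.Node`; `None` at a level means the node is not checkpointed at that level):

    from types import SimpleNamespace

    node = SimpleNamespace()                    # stand-in for a torch.fx.Node
    node.activation_checkpoint = [0, 0, 1]      # region 0 -> sub-region 0_0 -> sub-sub-region 0_0_1
    node.activation_checkpoint = [0, 0, None]   # inside 0 and 0_0, but no third-level region
    node.activation_checkpoint = 1              # a plain int keeps the original flat behaviour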
@@ -109,6 +109,209 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant):
    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'


def _end_of_ckpt(node: Node, check_idx: int) -> bool:
    """Check if the node could end the ckpt region

    Args:
        node (Node): torch.fx.Node
        check_idx (int): the index of the checkpoint level to inspect for
            nested checkpointing

    Returns:
        bool
    """
    if hasattr(node, "activation_checkpoint"):
        if isinstance(node.activation_checkpoint, list):
            return node.activation_checkpoint[check_idx] is None
        else:
            return False
    else:
        return True
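
A few concrete cases, assuming `_end_of_ckpt` above is in scope (a sketch; `SimpleNamespace` works here because the function only duck-types nodes via `hasattr`):

    from types import SimpleNamespace

    # no annotation at all -> ends any open region
    assert _end_of_ckpt(SimpleNamespace(), check_idx=0)
    # annotated, but level 1 is None -> ends a level-1 region
    assert _end_of_ckpt(SimpleNamespace(activation_checkpoint=[0, None]), check_idx=1)
    # int annotations belong to the flat codegen and never end a nested region
    assert not _end_of_ckpt(SimpleNamespace(activation_checkpoint=1), check_idx=0)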

def _find_nested_ckpt_regions(nodes, check_idx=0):
    """
    Find the nested checkpoint regions given a list of consecutive nodes. The output
    is a list of tuples, each in the form of (start_index, end_index).
    """
    ckpt_regions = []
    start = -1
    end = -1
    current_region = None

    for idx, node in enumerate(nodes):
        if hasattr(node, 'activation_checkpoint'):
            if isinstance(getattr(node, 'activation_checkpoint'), int):
                act_ckpt_label = node.activation_checkpoint
            else:
                act_ckpt_label = node.activation_checkpoint[check_idx]

            # the current region label is not set yet,
            # meaning this is the first node of an activation ckpt region
            if current_region is None:
                current_region = act_ckpt_label
                start = idx

            # if the activation checkpoint label has changed,
            # we restart the tracking
            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
            if act_ckpt_label != current_region:
                assert start != -1
                ckpt_regions.append((start, idx - 1))
                current_region = act_ckpt_label
                start = idx
                end = -1
        elif current_region is not None and _end_of_ckpt(node, check_idx):
            # handles the case below:
            # node ckpt states = [ckpt, ckpt, non-ckpt]
            end = idx - 1
            assert start != -1 and end != -1
            ckpt_regions.append((start, end))
            start = end = -1
            current_region = None
        else:
            pass

    if current_region is not None:
        end = len(nodes) - 1
        ckpt_regions.append((start, end))
    return ckpt_regions
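
A worked example of the region search, again assuming the function above is in scope and using `SimpleNamespace` as a stand-in for annotated fx nodes:

    from types import SimpleNamespace

    nodes = [
        SimpleNamespace(activation_checkpoint=[0, 0]),   # e.g. linear1
        SimpleNamespace(activation_checkpoint=[0, 0]),   # e.g. linear2
        SimpleNamespace(activation_checkpoint=[0, 1]),   # e.g. linear3
        SimpleNamespace(),                               # not checkpointed
        SimpleNamespace(activation_checkpoint=1),        # flat (non-nested) region
    ]

    # at the top level (check_idx=0) the first three nodes share label 0
    print(_find_nested_ckpt_regions(nodes, 0))        # [(0, 2), (4, 4)]
    # one level down, region 0 splits into two inner regions
    print(_find_nested_ckpt_regions(nodes[:3], 1))    # [(0, 1), (2, 2)]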

def emit_ckpt_func(body,
                   ckpt_func,
                   node_list: List[Node],
                   emit_node_func,
                   delete_unused_value_func,
                   level=0,
                   in_ckpt=False):
    """Emit a ckpt function in a nested way

    Args:
        body: forward code; in recursive calls, this part will be the checkpoint
            functions' code
        ckpt_func: checkpoint functions' code; in recursive calls, this part
            will be a buffer
        node_list (List[Node]): list of torch.fx.Node
        emit_node_func: function to emit a node
        delete_unused_value_func: function to delete unused values
        level (int, optional): checkpoint level. Defaults to 0.
        in_ckpt (bool, optional): indicates whether the func is in a recursive
            call. Defaults to False.
    """
    inputs, outputs = _find_input_and_output_nodes(node_list)

    # if the current checkpoint function uses an int as its label, use the old generation method
    if isinstance(node_list[0].activation_checkpoint, int):
        label = node_list[0].activation_checkpoint
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')
        for node in node_list:
            emit_node_func(node, ckpt_func)
            ckpt_func[-1] = '    ' + ckpt_func[-1]
            delete_unused_value_func(node, ckpt_func)

        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
        activation_offload = getattr(node_list[0], "activation_offload", False)
        usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
        usage += "\n"
        body.append(usage)

    # use nested ckpt function codegen
    else:
        # the label is given by each level, e.g. if you are currently at level [0, 1, 1]
        # the label will be '0_1_1'
        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
        ckpt_func.append(f'{ckpt_fn_def}\n')

        # if there are more levels to fetch
        if level + 1 < len(node_list[0].activation_checkpoint):
            ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
            start_idx = [item[0] for item in ckpt_regions]
            end_idx = [item[1] for item in ckpt_regions]

            # use ckpt_func_buffer to store nested checkpoint functions
            ckpt_func_buffer = []
            node_idx = 0
            while node_idx < len(node_list):
                if node_idx in start_idx:
                    ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
                    emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
                                   delete_unused_value_func, level + 1, True)
                    node_idx += len(ckpt_node_list)

                else:
                    node = node_list[node_idx]
                    emit_node_func(node, ckpt_func)
                    ckpt_func[-1] = '    ' + ckpt_func[-1]
                    delete_unused_value_func(node, ckpt_func)
                    node_idx += 1

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            ckpt_func += ckpt_func_buffer
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)

        # last level
        else:
            for node in node_list:
                emit_node_func(node, ckpt_func)
                ckpt_func[-1] = '    ' + ckpt_func[-1]
                delete_unused_value_func(node, ckpt_func)

            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
            activation_offload = getattr(node_list[0], "activation_offload", False)
            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
            if in_ckpt:
                usage = '    ' + usage
            body.append(usage)
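
For orientation, this is roughly the source the nested path emits for a two-level region labelled `[0]` / `[0, 0]` (a sketch pieced together from `_gen_ckpt_fn_def`/`_gen_ckpt_usage` above and the assertions in the new test below; the actual variable names depend on the traced graph):

    # inner region: its nodes are emitted directly into the function body
    def checkpoint_0_0(self, x):
        linear1 = self.linear1(x)
        linear2 = self.linear2(linear1)
        return linear2

    # outer region: calls the inner checkpoint function instead of inlining its nodes
    def checkpoint_0(self, x):
        linear2 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)
        linear3 = self.linear3(linear2)
        return linear3

    # forward() then invokes the outermost region the same way
    linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)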
def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit code with nested activation checkpoint
    When we detect that some node.activation_checkpoint is a List, we use
    this function to emit the activation checkpoint code.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit a node
        delete_unused_value_func: function to remove the unused values
    """
    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]

    node_list = list(nodes)

    node_idx = 0
    # stop once all the nodes have been processed
    while node_idx < len(node_list):
        # process ckpt_regions
        if node_idx in start_idx:
            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
            node_idx += len(ckpt_node_list)

        # process nodes in the forward function
        else:
            node = node_list[node_idx]
            emit_node_func(node, body)
            delete_unused_value_func(node, body)
            node_idx += 1

def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    # find the activation checkpoint regions
    ckpt_regions = _find_ckpt_regions(nodes)

@@ -384,7 +587,10 @@ if CODEGEN_AVAILABLE:

         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body

@@ -582,7 +788,10 @@ else:

         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
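
Put differently, the flat generator is kept only when no node carries a list-valued annotation. The two hunks above spell this out inline; an equivalent standalone predicate (hypothetical helper name, not part of this commit) would be:

    def needs_nested_codegen(nodes) -> bool:
        # True as soon as any node carries a list-valued activation_checkpoint,
        # i.e. at least one nested region was annotated
        return any(isinstance(getattr(node, "activation_checkpoint", None), list)
                   for node in nodes)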
@@ -0,0 +1,153 @@
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except:
    # fall back to older pytorch version
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)
        self.linear4 = torch.nn.Linear(4, 4)
        self.linear5 = torch.nn.Linear(4, 4)
        self.linear6 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x))))))


def _run_act_ckpt_codegen(rank):
    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and run forward
    model = MyModule()
    data1 = torch.rand(4, 4)

    # copy model to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate nested checkpoint
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the checkpoint functions are generated
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run the recompiled module and verify that the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    mp.spawn(_run_act_ckpt_codegen, nprocs=1)


def _run_act_ckpt_python_code_torch11(rank):
    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and run forward
    model = MyModule()
    data1 = torch.rand(4, 4)

    # copy model to cuda
    model = model.to(device="cuda")
    data1 = data1.to(device="cuda")

    non_fx_out = model(data1)

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate nested checkpoint
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that the checkpoint functions are generated
    code = graph.python_code('self').src
    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code

    # run the recompiled module and verify that the outputs are consistent
    fx_out = gm(data1)
    assert torch.equal(non_fx_out, fx_out)

    gpc.destroy()


@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)


if __name__ == '__main__':
    _run_act_ckpt_codegen(rank=0)