Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-11 21:01:54 +00:00)
[fx] Add offload codegen (#1598)
* [fx] add input activation offload to codegen
* [fx] modify unit test
* [fx] remove two skips in torch11
* [fx] use all_input_nodes instead of _input_nodes
parent c8e9b2ad78
commit a7cda6f57d
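The offload path added by this commit relies on PyTorch's torch.autograd.graph.saved_tensors_hooks context manager: the pack hook runs when autograd saves an activation for the backward pass and can move it to CPU, and the unpack hook runs when the backward pass needs the tensor again and moves it back to its original device. Below is a minimal standalone sketch of that mechanism, independent of the generated code in this diff; here every saved tensor is offloaded unconditionally, whereas the generated hooks only offload tensors annotated with an "offload" attribute.

import torch

def pack_hook(x):
    # called when autograd saves a tensor for backward: offload it to CPU
    return (x.device, x.cpu())

def unpack_hook(packed):
    # called when backward needs the tensor again: move it back
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.matmul(x, w).relu()    # activations saved inside this block live on CPU

y.sum().backward()                   # unpack_hook restores them on demand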
@@ -17,6 +17,38 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']


+def _gen_saved_tensors_hooks():
+    """
+    Generate saved tensors hooks
+    """
+
+    pack_hook = """def pack_hook(self, x):
+    if getattr(x, "offload", None):
+        return (x.device, x.cpu())
+    else:
+        return x
+"""
+
+    unpack_hook = """def unpack_hook(self, packed):
+    if isinstance(packed, tuple):
+        device, tensor = packed
+        return tensor.to(device)
+    else:
+        return packed
+"""
+
+    return pack_hook, unpack_hook
+
+
+def _gen_save_tensors_hooks_context():
+    """
+    Generate save tensors hooks context
+    """
+
+    context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
+    return context
+
+
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
     Find the input and output node names which are not found in the given list of nodes.
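Taken together with the emit logic later in this diff, the source generated for a node annotated with activation_offload is expected to look roughly like the snippet below: the pack_hook/unpack_hook definitions above are prepended to ckpt_func and bound to the GraphModule, while the forward body annotates the producing node's output and wraps the call in the hooks context. This is an illustrative reconstruction based on the string templates above and the assertions in the new test; the variable names linear1 and linear2 come from the test's MyNet model.

# inside the generated forward(), for a node marked activation_offload:
setattr(linear1, 'offload', True)    # tag the input activation for pack_hook
with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
    linear2 = self.linear2(linear1)    # the saved input activation is packed to CPU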
@@ -211,7 +243,7 @@ def emit_ckpt_func(body,
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)

-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
         usage += "\n"
@@ -251,7 +283,7 @@ def emit_ckpt_func(body,
                 delete_unused_value_func(node, ckpt_func)
                 node_idx += 1

-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         ckpt_func += ckpt_func_buffer
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
@@ -266,7 +298,7 @@ def emit_ckpt_func(body,
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)

-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
         if in_ckpt:
@@ -292,6 +324,9 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod

     node_list = list(nodes)

+    # this flag prevents repeated insertion of the saved tensors
+    # hooks definition into ckpt_func
+    is_hook_inserted = False
     node_idx = 0
     while 1:
         # break if we finish processing all the nodes
@@ -307,6 +342,25 @@
         # process node in forward function
         else:
             node = node_list[node_idx]
+
+            # if a node outside of a checkpoint region wants to offload
+            # its input activation, we will use torch.autograd.graph.saved_tensors_hooks
+            # to complete the offload process.
+            if getattr(node, "activation_offload", False):
+                if not is_hook_inserted:
+                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+
+                for par in node.all_input_nodes:
+                    # annotate the input tensor for the pack hook
+                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+
+                body.append(_gen_save_tensors_hooks_context())
+                emit_node_func(node, body)
+                body[-1] = '    ' + body[-1]
+                delete_unused_value_func(node, body)
+
+            else:
                 emit_node_func(node, body)
                 delete_unused_value_func(node, body)
             node_idx += 1
@@ -323,6 +377,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,

     node_list = list(nodes)

+    # use this variable to avoid inserting hook functions
+    # into ckpt_func repeatedly
+    is_hook_inserted = False
+
     # find the input and output var names for each region
     for idx, (start, end) in enumerate(ckpt_regions):
         ckpt_node_list = node_list[start:end + 1]
@@ -347,6 +405,24 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             emit_node_func(node, ckpt_func)
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
+        else:
+            # if a node outside of a checkpoint region wants to offload
+            # its input activation, we will use torch.autograd.graph.saved_tensors_hooks
+            # to complete the offload process.
+            if getattr(node, "activation_offload", False):
+                if not is_hook_inserted:
+                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+
+                for par in node.all_input_nodes:
+                    # annotate the input tensor for the pack hook
+                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+
+                body.append(_gen_save_tensors_hooks_context())
+                emit_node_func(node, body)
+                body[-1] = '    ' + body[-1]
+                delete_unused_value_func(node, body)
+
             else:
                 emit_node_func(node, body)
                 delete_unused_value_func(node, body)
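The for par in node.all_input_nodes loops above use torch.fx's public Node API: all_input_nodes returns every node that produces an input of the current node, covering both positional args and kwargs, which is what the commit message refers to when it switches away from the private _input_nodes mapping. A small illustrative sketch of what that property yields on a traced module follows; the TwoLinear toy model is an assumption made only for this example.

import torch
from torch.fx import symbolic_trace

class TwoLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))

gm = symbolic_trace(TwoLinear())
for node in gm.graph.nodes:
    # e.g. the 'linear2' call_module node reports ['linear1'] as its producer
    print(node.name, [par.name for par in node.all_input_nodes])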
@@ -587,10 +663,13 @@ if CODEGEN_AVAILABLE:

            # Modified for activation checkpointing
            ckpt_func = []
-           if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
-               emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
-           else:
+
+           # if any node has a list of labels for activation_checkpoint,
+           # we will use the nested activation checkpoint codegen
+           if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
                emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+           else:
+               emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
@@ -612,7 +691,6 @@ if CODEGEN_AVAILABLE:

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
-           # TODO: Remove inline import
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = ''.join(ckpt_func) + prologue
            prologue = prologue
@@ -788,10 +866,13 @@ else:

        # Modified for activation checkpointing
        ckpt_func = []
-       if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
-           emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
-       else:
+
+       # if any node has a list of labels for activation_checkpoint,
+       # we will use the nested activation checkpoint codegen
+       if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+       else:
+           emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
@@ -827,7 +908,6 @@ else:

        # as we need colossalai.utils.checkpoint, we need to import colossalai
        # in forward function
-       # TODO: Remove inline import
        fn_code = f"""
{wrap_stmts}

@@ -22,14 +22,20 @@ if COLOGM:
         super().__init__(root, graph, class_name)

     def bind(self, ckpt_def, globals):
-        """Bind checkpoint functions to ColoGraphModule
-        We need to bind our checkpoint functions to the GraphModule so
-        that we could correctly use self.checkpoint for GraphModule forward
+        """Bind the functions needed to correctly execute the gm forward
+
+        We need to bind the checkpoint functions and the saved_tensors_hooks
+        functions to the gm so that its forward executes correctly
+
+        Args:
+            ckpt_def (_type_): the definitions placed before the forward function
+            globals (_type_): global variables
         """
+
         ckpt_code = "\n".join(ckpt_def)
         globals_copy = globals.copy()
         _exec_with_source(ckpt_code, globals_copy)
-        func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
+        func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
         for func in func_list:
             tmp_func = globals_copy[func]
             setattr(self, func, tmp_func.__get__(self, self.__class__))
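The setattr(self, func, tmp_func.__get__(self, self.__class__)) pattern in bind() works because Python functions are descriptors: __get__ turns a plain function into a method bound to a specific instance, which is how the generated checkpoint_N, pack_hook, and unpack_hook functions become callable as self.checkpoint_N, self.pack_hook, and self.unpack_hook inside the generated forward. A minimal sketch of the mechanism; the Greeter class and greet function are illustrative only.

class Greeter:

    def __init__(self, name):
        self.name = name

def greet(self):
    # 'self' is supplied automatically once the function is bound via __get__
    return f"hello from {self.name}"

g = Greeter("gm")
# mirror ColoGraphModule.bind: attach the free function as a bound method
setattr(g, "greet", greet.__get__(g, g.__class__))
print(g.greet())    # -> hello from gm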
@@ -1,4 +1,3 @@
-from operator import mod
 import torch
 import torch.nn.functional as F
 import pytest
tests/test_fx/test_codegen/test_offload_codegen.py (new file, 159 lines)
@@ -0,0 +1,159 @@
import copy
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except:
    # fall back to older pytorch version
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyNet(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)
        self.linear4 = torch.nn.Linear(4, 4)
        self.linear5 = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        return x


def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if not torch.allclose(m_p.grad, gm_p.grad):
            return False
    return True


def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):

    # test forward
    non_fx_out = model(data)
    fx_out = gm(data)
    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"

    # test backward
    loss0 = non_fx_out.sum()
    loss0.backward()
    loss1 = fx_out.sum()
    loss1.backward()
    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"


def _run_offload_codegen(rank):
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the activation offload part
    # also annotate the activation_checkpoint so we could test both types
    # of input offload
    for node in graph.nodes:
        if node.name == "linear2":
            setattr(node, "activation_offload", True)
        if node.name == "linear3":
            setattr(node, "activation_offload", True)
            setattr(node, "activation_checkpoint", [0])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0])

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
    print(gm)

    # assert we have all the components
    code = graph.python_code("self").src
    assert "def pack_hook(self, x):" in code and \
        "def unpack_hook(self, packed):" in code and \
        "setattr(linear1, 'offload', True)" in code and \
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    mp.spawn(_run_offload_codegen, nprocs=1)


def _run_offload_codegen_torch11(rank):
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)

    # replace a bound method of an object
    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

    # annotate the activation offload part
    # also annotate the activation_checkpoint so we could test both types
    # of input offload
    for node in graph.nodes:
        if node.name == "linear2":
            setattr(node, "activation_offload", True)
        if node.name == "linear3":
            setattr(node, "activation_offload", True)
            setattr(node, "activation_checkpoint", [0])
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0])

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
    print(gm)

    # assert we have all the components
    code = graph.python_code("self").src
    assert "def pack_hook(self, x):" in code and \
        "def unpack_hook(self, packed):" in code and \
        "setattr(linear1, 'offload', True)" in code and \
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_python_code_torch11():
    mp.spawn(_run_offload_codegen_torch11, nprocs=1)


if __name__ == "__main__":
    _run_offload_codegen(0)
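Assuming a standard pytest setup on a CUDA-capable machine (the test launches a single-process NCCL group through colossalai.launch), the new test file can presumably be run directly with:

pytest tests/test_fx/test_codegen/test_offload_codegen.py

Running with the -s flag keeps the print(gm) output visible.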