Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
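Note: the diff below is mechanical reformatting produced by the updated pre-commit hooks; no behaviour changes. The two recurring rewrites are single-quoted strings becoming double-quoted and backslash-continued assert chains being rewrapped in parentheses. A hypothetical before/after sketch of the same rewrite (illustrative only, not code from the repository):

# Hypothetical example, not from the repository: before the pre-commit run.
def check_before(code):
    backend = 'nccl'
    assert 'checkpoint_0' in code and \
        'checkpoint_1' in code
    return backend

# After the run: double quotes and a parenthesized condition, same behaviour.
def check_after(code):
    backend = "nccl"
    assert (
        "checkpoint_0" in code
        and "checkpoint_1" in code
    )
    return backend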
@@ -11,15 +11,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen

     with_codegen = True
 except:
     # fall back to older pytorch version
     from colossalai.fx.codegen import python_code_with_activation_checkpoint

     with_codegen = False


 class MLP(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(4, 4)
@@ -30,7 +31,6 @@ class MLP(torch.nn.Module):


 class relu(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.relu = torch.nn.ReLU(inplace=True)
@@ -40,7 +40,6 @@ class relu(torch.nn.Module):


 class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.mlp1 = MLP()
@@ -65,7 +64,7 @@ class MyModule(torch.nn.Module):

 def _run_act_ckpt_codegen(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and run forward
     model = MyModule()
@@ -87,26 +86,31 @@ def _run_act_ckpt_codegen(rank, world_size, port):

     # check ops are annotated with ckpt
     # also annotate the selected node for offloading
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
-    offload_starts = ['mlp1_linear1']
+    ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"]
+    offload_starts = ["mlp1_linear1"]
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert 'activation_checkpoint' in node.meta
+            assert "activation_checkpoint" in node.meta

         # annotate the selected node for offload
         if node.name in offload_starts:
-            node.meta['activation_offload'] = True
+            node.meta["activation_offload"] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert checkpoint function will be generated and
     # the offload option is correct
-    code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
+    code = graph.python_code("self").src
+    assert (
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)"
+        in code
+    )

     # recompile and verify the outputs are consistent
     fx_out = gm(data1, data2)
@@ -115,7 +119,7 @@ def _run_act_ckpt_codegen(rank, world_size, port):
     gpc.destroy()


-@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 @rerun_if_address_is_in_use()
 def test_act_ckpt_codegen():
     spawn(_run_act_ckpt_codegen, 1)
@@ -123,7 +127,7 @@ def test_act_ckpt_codegen():

 def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and run forward
     model = MyModule()
@@ -144,25 +148,30 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

     # check ops are annotated with ckpt
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
-    offload_starts = ['mlp1_linear1']
+    ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"]
+    offload_starts = ["mlp1_linear1"]
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert 'activation_checkpoint' in node.meta
+            assert "activation_checkpoint" in node.meta

         # annotate the selected node for offload
         if node.name in offload_starts:
-            node.meta['activation_offload'] = True
+            node.meta["activation_offload"] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
-    code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
+    code = graph.python_code("self").src
+    assert (
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)"
+        in code
+    )

     # recompile and verify the outputs are consistent
     fx_out = gm(data1, data2)
@@ -171,12 +180,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     gpc.destroy()


-@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
 @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
 @rerun_if_address_is_in_use()
 def test_act_ckpt_python_code_torch11():
     spawn(_run_act_ckpt_python_code_torch11, 1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     _run_act_ckpt_codegen(rank=0)
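The test above tags FX graph nodes through node.meta["activation_checkpoint"] / node.meta["activation_offload"] and then asserts on the source that ColossalAI's codegen emits. A minimal standalone sketch of that tagging pattern using plain torch.fx (assumption: torch >= 1.12; ColoGraphModule and the ColossalAI codegen are deliberately omitted here):

# Minimal sketch, plain torch.fx only. It shows how the tests above tag nodes;
# ColoGraphModule / ActivationCheckpointCodeGen consume these meta keys when
# emitting the checkpointed forward. TinyMLP is a hypothetical stand-in.
import torch
from torch.fx import symbolic_trace


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))


gm = symbolic_trace(TinyMLP())
for node in gm.graph.nodes:
    if node.name == "linear1":
        node.meta["activation_checkpoint"] = 0  # place this op in checkpoint region 0
        node.meta["activation_offload"] = True  # and mark its input for CPU offload
print(gm.graph.python_code("self").src)  # inspect the generated forward source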
@@ -9,15 +9,14 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen

     with_codegen = True
 except:
     # fall back to older pytorch version
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False


 class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(4, 4)
@@ -33,7 +32,7 @@ class MyModule(torch.nn.Module):

 def _run_act_ckpt_codegen(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and run forward
     model = MyModule()
@@ -54,27 +53,34 @@ def _run_act_ckpt_codegen(rank, world_size, port):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            node.meta['activation_checkpoint'] = [0, 0, 0]
+            node.meta["activation_checkpoint"] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            node.meta['activation_checkpoint'] = [0, 0, None]
+            node.meta["activation_checkpoint"] = [0, 0, None]
         if node.name == "linear3":
-            node.meta['activation_checkpoint'] = [0, 0, 1]
+            node.meta["activation_checkpoint"] = [0, 0, 1]
         if node.name == "linear4":
-            node.meta['activation_checkpoint'] = [0, 1, None]
+            node.meta["activation_checkpoint"] = [0, 1, None]
         if node.name == "linear5":
-            node.meta['activation_checkpoint'] = 1
+            node.meta["activation_checkpoint"] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert checkpoint function will be generated and
-    code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code
+    code = graph.python_code("self").src
+    assert (
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)"
+        in code
+    )

     # recompile and verify the outputs are consistent
     fx_out = gm(data1)
@@ -83,14 +89,14 @@ def _run_act_ckpt_codegen(rank, world_size, port):
     gpc.destroy()


-@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 def test_act_ckpt_codegen():
     spawn(_run_act_ckpt_codegen, 1)


 def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and run forward
     model = MyModule()
@@ -111,27 +117,34 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            node.meta['activation_checkpoint'] = [0, 0, 0]
+            node.meta["activation_checkpoint"] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            node.meta['activation_checkpoint'] = [0, 0, None]
+            node.meta["activation_checkpoint"] = [0, 0, None]
         if node.name == "linear3":
-            node.meta['activation_checkpoint'] = [0, 0, 1]
+            node.meta["activation_checkpoint"] = [0, 0, 1]
         if node.name == "linear4":
-            node.meta['activation_checkpoint'] = [0, 1, None]
+            node.meta["activation_checkpoint"] = [0, 1, None]
         if node.name == "linear5":
-            node.meta['activation_checkpoint'] = 1
+            node.meta["activation_checkpoint"] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert checkpoint function will be generated and
-    code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code
+    code = graph.python_code("self").src
+    assert (
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)"
+        in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)"
+        in code
+    )

     # recompile and verify the outputs are consistent
     fx_out = gm(data1)
@@ -140,12 +153,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port):
     gpc.destroy()


-@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
 @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
 @rerun_if_address_is_in_use()
 def test_act_ckpt_python_code_torch11():
     spawn(_run_act_ckpt_python_code_torch11, 1)


-if __name__ == '__main__':
+if __name__ == "__main__":
     _run_act_ckpt_codegen(rank=0)
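In the nested test above, the annotation value encodes a nesting path of checkpoint regions. Reading the annotations against the wrapper names asserted in the generated code suggests the following interpretation (an inference from this test, not a documented contract):

# Interpretation only (assumption): each list element is the region index at one
# nesting depth, None means "not inside a region at that depth", and a bare int
# is a plain top-level region. Comments give the wrapper the test expects.
nested_annotation = {
    "linear1": [0, 0, 0],     # innermost region -> checkpoint_0_0_0
    "linear2": [0, 0, None],  # inside region 0_0 but no deeper -> checkpoint_0_0
    "linear3": [0, 0, 1],     # innermost region -> checkpoint_0_0_1
    "linear4": [0, 1, None],  # inside region 0_1 -> checkpoint_0_1
    "linear5": 1,             # top-level region -> checkpoint_1
}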
@@ -12,15 +12,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen

     with_codegen = True
 except:
     # fall back to older pytorch version
     from colossalai.fx.codegen import python_code_with_activation_checkpoint

     with_codegen = False


 class MyNet(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.linear0 = torch.nn.Linear(4, 4)
@@ -50,7 +51,6 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:


 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
     # test forward
     non_fx_out = model(data)
     fx_out = gm(data)
@@ -66,7 +66,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T

 def _run_offload_codegen(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and input
     model = MyNet().cuda()
@@ -83,37 +83,40 @@ def _run_offload_codegen(rank, world_size, port):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            node.meta['activation_offload'] = [0, True, False]
+            node.meta["activation_offload"] = [0, True, False]
         if node.name == "linear1":
-            node.meta['activation_offload'] = [0, True, False]
+            node.meta["activation_offload"] = [0, True, False]
         if node.name == "linear2":
-            node.meta['activation_offload'] = [1, True, True]
+            node.meta["activation_offload"] = [1, True, True]
         if node.name == "linear4":
-            node.meta['activation_offload'] = [2, False, True]
+            node.meta["activation_offload"] = [2, False, True]
         if node.name == "linear5":
-            node.meta['activation_checkpoint'] = [0]
-            node.meta['activation_offload'] = True
+            node.meta["activation_checkpoint"] = [0]
+            node.meta["activation_offload"] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()

     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook_input(self, x):" in code and \
-        "def unpack_hook(self, packed):" in code and \
-        "def pack_hook_no_input(self, x):" in code and \
-        "setattr(x, 'offload', True)" in code and \
-        "setattr(linear3, 'offload', False)" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
-        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
-        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
+    assert (
+        "def pack_hook_input(self, x):" in code
+        and "def unpack_hook(self, packed):" in code
+        and "def pack_hook_no_input(self, x):" in code
+        and "setattr(x, 'offload', True)" in code
+        and "setattr(linear3, 'offload', False)" in code
+        and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code
+        and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code
+        and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)"
+        in code
+    )

     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()


-@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
 @rerun_if_address_is_in_use()
 def test_act_ckpt_codegen():
     spawn(_run_offload_codegen, 1)
@@ -121,7 +124,7 @@ def test_act_ckpt_codegen():

 def _run_offload_codegen_torch11(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

     # build model and input
     model = MyNet().cuda()
@@ -139,31 +142,34 @@ def _run_offload_codegen_torch11(rank, world_size, port):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            node.meta['activation_offload'] = [0, True, False]
+            node.meta["activation_offload"] = [0, True, False]
         if node.name == "linear1":
-            node.meta['activation_offload'] = [0, True, False]
+            node.meta["activation_offload"] = [0, True, False]
         if node.name == "linear2":
-            node.meta['activation_offload'] = [1, True, True]
+            node.meta["activation_offload"] = [1, True, True]
         if node.name == "linear4":
-            node.meta['activation_offload'] = [2, False, True]
+            node.meta["activation_offload"] = [2, False, True]
         if node.name == "linear5":
-            node.meta['activation_checkpoint'] = [0]
-            node.meta['activation_offload'] = True
+            node.meta["activation_checkpoint"] = [0]
+            node.meta["activation_offload"] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()

     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook_input(self, x):" in code and \
-        "def unpack_hook(self, packed):" in code and \
-        "def pack_hook_no_input(self, x):" in code and \
-        "setattr(x, 'offload', True)" in code and \
-        "setattr(linear3, 'offload', False)" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
-        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
-        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
+    assert (
+        "def pack_hook_input(self, x):" in code
+        and "def unpack_hook(self, packed):" in code
+        and "def pack_hook_no_input(self, x):" in code
+        and "setattr(x, 'offload', True)" in code
+        and "setattr(linear3, 'offload', False)" in code
+        and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code
+        and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code
+        and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code
+        and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)"
+        in code
+    )

     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
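The generated offload code asserted above is built on stock PyTorch saved-tensor hooks. A standalone sketch of that underlying mechanism (plain PyTorch, no ColossalAI; hook names here are illustrative):

# Sketch of the PyTorch hooks the generated offload code relies on: activations
# saved for backward are moved to CPU in the pack hook and restored to their
# original device in the unpack hook.
import torch


def pack_to_cpu(tensor):
    return tensor.device, tensor.cpu()


def unpack_from_cpu(packed):
    device, tensor = packed
    return tensor.to(device)


x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x @ w).relu().sum()
y.backward()  # the unpack hook restores saved activations for the backward pass

# torch.autograd.graph.save_on_cpu(pin_memory=True) is the built-in shortcut for
# the same pattern and is what the asserted generated code uses.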