[fx] Modify offload codegen (#1618)

* [fx] modify offload codegen

* [fx] remove repeated hook definitions

* [fx] modify offload test
Author: Boyuan Yao
Date: 2022-09-23 11:04:52 +08:00
Committer: GitHub
Parent: 9eae855408
Commit: d6b01feb66
2 changed files with 200 additions and 52 deletions

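For context on the generated code the tests below inspect: the pack/unpack hooks this codegen emits (pack_hook_input, pack_hook_no_input, unpack_hook) are built on PyTorch's torch.autograd.graph.saved_tensors_hooks context manager, whose pack hook runs when autograd saves an activation for backward and whose unpack hook restores it when backward needs it. A minimal hand-written sketch of that offload pattern (illustrative only, not the code this codegen actually generates):

    import torch

    def pack_hook(x):
        # Called when autograd saves a tensor for backward:
        # keep a CPU copy together with the original device.
        return (x.device, x.to("cpu"))

    def unpack_hook(packed):
        # Called when backward needs the saved tensor: move it back.
        device, cpu_tensor = packed
        return cpu_tensor.to(device)

    def forward_with_offload(module, inp):
        # Every tensor saved for backward inside this context is routed
        # through the hooks above.
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            return module(inp)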

@@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.linear0 = torch.nn.Linear(4, 4)
         self.linear1 = torch.nn.Linear(4, 4)
         self.linear2 = torch.nn.Linear(4, 4)
         self.linear3 = torch.nn.Linear(4, 4)
         self.linear4 = torch.nn.Linear(4, 4)
+        self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)
     def forward(self, x):
         x = self.linear0(x)
         x = self.linear1(x)
         x = self.linear2(x)
         x = self.linear3(x)
         x = self.linear4(x)
+        x = self.linear5(x)
+        x = self.linear6(x)
         return x
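The annotation loops below look nodes up by name (linear0 ... linear5). For reference, tracing a module of this shape with torch.fx yields call_module nodes named after the submodule attributes, which is what those name checks rely on. A small sketch using the stock torch.fx.symbolic_trace (ColossalAI's own tracer is not shown in this diff):

    import torch
    import torch.fx

    class TinyNet(torch.nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.linear0 = torch.nn.Linear(4, 4)
            self.linear1 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear1(self.linear0(x))

    traced = torch.fx.symbolic_trace(TinyNet())
    for node in traced.graph.nodes:
        # Prints: x (placeholder), linear0 / linear1 (call_module), output (output)
        print(node.name, node.op)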
@@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
             setattr(node, "activation_offload", True)
         if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
+        if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
     print(gm)
     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
         "def unpack_hook(self, packed):" in code and \
-        "setattr(linear1, 'offload', True)" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+        "def pack_hook_no_input(self, x):" in code and \
+        "setattr(x, 'offload', True)" in code and \
+        "setattr(linear3, 'offload', False)" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
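Among the strings the updated assertion looks for is a torch.autograd.graph.save_on_cpu(pin_memory=True) block, PyTorch's built-in shortcut for the same pack/unpack offload pattern (apparently emitted for nodes annotated with a plain activation_offload = True). A small usage sketch of that context manager; the layer and input here are placeholders, and pinned memory assumes a CUDA-capable machine:

    import torch

    layer = torch.nn.Linear(4, 4).cuda()
    data = torch.rand(4, 4, device="cuda", requires_grad=True)

    # Activations saved for backward inside this block are kept in pinned
    # CPU memory and copied back to the GPU during the backward pass.
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        out = layer(data)
    out.sum().backward()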
@@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
             setattr(node, "activation_offload", True)
         if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
+        if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
     print(gm)
     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
         "def unpack_hook(self, packed):" in code and \
-        "setattr(linear1, 'offload', True)" in code and \
-        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+        "def pack_hook_no_input(self, x):" in code and \
+        "setattr(x, 'offload', True)" in code and \
+        "setattr(linear3, 'offload', False)" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
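_test_fwd_and_bwd is the test module's own helper and is not shown in this diff; the rewritten codegen is accepted when the recompiled ColoGraphModule behaves like the eager model. A rough, hypothetical sketch of what such a check can look like (not the helper's actual implementation):

    import torch

    def check_fwd_and_bwd(model, gm, data):
        # gm was built from a deepcopy of model, so the two hold independent
        # but identical parameters: outputs and gradients should match.
        out_ref = model(data)
        out_gm = gm(data)
        assert torch.allclose(out_ref, out_gm)

        out_ref.sum().backward()
        out_gm.sum().backward()
        for p_ref, p_gm in zip(model.parameters(), gm.parameters()):
            assert torch.allclose(p_ref.grad, p_gm.grad)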