Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[fx] Modify offload codegen (#1618)
* [fx] modify offload codegen
* [fx] remove repeated hook definitions
* [fx] modify offload test
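The modified test asserts on two PyTorch mechanisms for moving saved activations to the CPU. As a quick orientation, here is a minimal, self-contained sketch of both; it is plain PyTorch, not the code ColossalAI generates, and the hook bodies are illustrative only:

# Minimal sketch of the two saved-tensor offload mechanisms the assertions
# below look for. This is plain PyTorch, not ColossalAI's generated code.
import torch

linear = torch.nn.Linear(4, 4)
x = torch.rand(4, 4, requires_grad=True)

# 1) Custom pack/unpack hooks: stash tensors saved for backward on the CPU,
#    and move them back to their original device when backward needs them.
def pack_hook(saved):
    return saved.device, saved.to("cpu")

def unpack_hook(packed):
    device, saved = packed
    return saved.to(device)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = linear(x)

# 2) The built-in convenience wrapper that does the same thing, optionally
#    through pinned host memory when CUDA is available.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    z = linear(x)

(y.sum() + z.sum()).backward()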
@@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
         self.linear0 = torch.nn.Linear(4, 4)
         self.linear1 = torch.nn.Linear(4, 4)
         self.linear2 = torch.nn.Linear(4, 4)
         self.linear3 = torch.nn.Linear(4, 4)
         self.linear4 = torch.nn.Linear(4, 4)
+        self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)

     def forward(self, x):
         x = self.linear0(x)
         x = self.linear1(x)
         x = self.linear2(x)
         x = self.linear3(x)
         x = self.linear4(x)
+        x = self.linear5(x)
+        x = self.linear6(x)
         return x

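The annotation loop in the next hunk matches graph nodes by name ("linear0", "linear1", ...). For context, this sketch shows where those names come from; it uses plain torch.fx symbolic tracing as a stand-in for the ColoTracer the test actually uses:

# Sketch: fx call_module nodes are named after the module attribute they call,
# which is what the `node.name == "linear0"` checks below rely on.
import torch

class TwoLinear(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear0 = torch.nn.Linear(4, 4)
        self.linear1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear1(self.linear0(x))

graph = torch.fx.symbolic_trace(TwoLinear()).graph
print([node.name for node in graph.nodes])   # ['x', 'linear0', 'linear1', 'output']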
@@ -78,25 +82,32 @@ def _run_offload_codegen(rank):

     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
             setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear1":
             setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
+        if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
     print(gm)

     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code

     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()

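The asserted strings suggest two hook flavours: one that offloads a region's input and one that offloads everything except it, both steered by an `offload` attribute set on the relevant tensor. The sketch below is one plausible reading of those strings, not a reconstruction of the actual generated module; the class name and hook bodies are assumptions:

# Illustrative reading of the asserted hook names -- not ColossalAI's codegen output.
import torch

class OffloadHooks:

    def pack_hook_input(self, x):
        # offload a saved tensor only if it was flagged with `offload`
        if getattr(x, "offload", False):
            return (x.device, x.to("cpu"))
        return (x.device, x)

    def pack_hook_no_input(self, x):
        # offload every saved tensor unless it was explicitly flagged False
        if getattr(x, "offload", True):
            return (x.device, x.to("cpu"))
        return (x.device, x)

    def unpack_hook(self, packed):
        device, saved = packed
        return saved.to(device)

hooks = OffloadHooks()
linear0 = torch.nn.Linear(4, 4)
x = torch.rand(4, 4, requires_grad=True)
setattr(x, "offload", True)   # mirrors the asserted "setattr(x, 'offload', True)"
with torch.autograd.graph.saved_tensors_hooks(hooks.pack_hook_input, hooks.unpack_hook):
    out = linear0(x)
out.sum().backward()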
@@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):

     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
             setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear1":
             setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
+        if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
     print(gm)

     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code

     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()

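Both test functions also assert a checkpointed region (checkpoint_0) that takes linear4 as input. As a rough analogue, here is the same pattern with stock torch.utils.checkpoint; colossalai.utils.activation_checkpoint.checkpoint adds the extra boolean offload flag seen in the asserted string, and its internals are not shown here:

# Rough stand-in for the asserted checkpoint call, using plain PyTorch.
# checkpoint_0 is recomputed during backward instead of storing its activations.
import torch
from torch.utils.checkpoint import checkpoint

linear4 = torch.nn.Linear(4, 4)
linear5 = torch.nn.Linear(4, 4)
data = torch.rand(4, 4, requires_grad=True)

def checkpoint_0(x):
    return linear5(x)

out = checkpoint(checkpoint_0, linear4(data), use_reentrant=False)
out.sum().backward()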