mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-13 13:45:51 +00:00
[fx] added activation checkpointing annotation (#1349)
* [fx] added activation checkpointing annotation * polish code * polish code
This commit is contained in:
parent
051592c64e
commit
05fae1fd56
@ -8,11 +8,12 @@ import enum
|
||||
import inspect
|
||||
import functools
|
||||
import operator
|
||||
from contextlib import contextmanager
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.fx import Tracer
|
||||
from torch.fx import Tracer, Node
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.proxy import Proxy, ParameterProxy
|
||||
from ..proxy import ColoProxy
|
||||
@ -55,11 +56,17 @@ class ColoTracer(Tracer):
|
||||
graph = tracer.trace(model, concrete_args={'y': torch.rand(4, 10)}, meta_args={'x': torch.rand(4, 10, device='meta')})
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tracer_type = TracerType.META
|
||||
self.proxy_cls = ColoProxy
|
||||
|
||||
# whether the tracer will record the usage of torch.utils.checkpoint
|
||||
self.trace_act_ckpt = trace_act_ckpt
|
||||
# whether the current tracing occurs within the activation checkpoint functions
|
||||
self.inside_torch_checkpoint_func = False
|
||||
self.act_ckpt_region_count = 0
|
||||
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
|
||||
@ -297,7 +304,10 @@ class ColoTracer(Tracer):
|
||||
self.orig_torch_tensor_methods = [val[1] for val in self.patched_torch_tensor_methods.values()]
|
||||
|
||||
try:
|
||||
# to track the usage of torch.utils.checkpoint
|
||||
with self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
|
||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||
|
||||
finally:
|
||||
# recover the patched methods
|
||||
for name, (_, orig) in self.patched_torch_tensor_methods.items():
|
||||
@ -338,6 +348,43 @@ class ColoTracer(Tracer):
|
||||
|
||||
return self.graph
|
||||
|
||||
@contextmanager
|
||||
def trace_activation_checkpoint(self, enabled: bool):
|
||||
if enabled:
|
||||
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
|
||||
|
||||
class PatchedCheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, preserve_rng_state, *args):
|
||||
# signal that the current tracing occurs within activaton checkpoint part
|
||||
self.inside_torch_checkpoint_func = True
|
||||
out = run_function(*args)
|
||||
self.inside_torch_checkpoint_func = False
|
||||
self.act_ckpt_region_count += 1
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
raise NotImplementedError(
|
||||
"We do not implement the backward pass as we only trace the forward pass.")
|
||||
|
||||
# override the checkpoint function
|
||||
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
|
||||
yield
|
||||
|
||||
if enabled:
|
||||
# recover the checkpoint function upon exit
|
||||
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
|
||||
|
||||
def create_node(self, *args, **kwargs) -> Node:
|
||||
node = super().create_node(*args, **kwargs)
|
||||
|
||||
if self.inside_torch_checkpoint_func:
|
||||
# annotate the activation checkpoint module
|
||||
setattr(node, 'activation_checkpoint', self.act_ckpt_region_count)
|
||||
return node
|
||||
|
||||
|
||||
def wrap_tensor_constructor_method(target):
|
||||
|
||||
@ -367,7 +414,7 @@ def wrap_tensor_constructor_method(target):
|
||||
colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||
if not isinstance(colo_proxy, ColoProxy):
|
||||
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
|
||||
colo_proxy = ColoProxy(fx_proxy.node)
|
||||
colo_proxy = ColoProxy(proxy.node)
|
||||
colo_proxy.meta_data = meta_out
|
||||
return colo_proxy
|
||||
else:
|
||||
|
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(4, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
# Simple module for demonstration
|
||||
class MyModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mlp_1 = MLP()
|
||||
self.mlp_2 = MLP()
|
||||
self.output = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = checkpoint(self.mlp_1, x)
|
||||
x = checkpoint(self.mlp_2, x)
|
||||
x = self.output(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_activation_checkpoint_annotation():
|
||||
module = MyModule()
|
||||
|
||||
# test tracing with activation checkpoint
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
graph = tracer.trace(module)
|
||||
gm = GraphModule(module, graph)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
|
||||
assert getattr(node, 'activation_checkpoint', -1) == 0
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
|
||||
assert getattr(node, 'activation_checkpoint', -1) == 1
|
||||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
graph = tracer.trace(module)
|
||||
gm = GraphModule(module, graph)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
assert not hasattr(node, 'activation_checkpoint')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_activation_checkpoint_annotation()
|
Loading…
Reference in New Issue
Block a user