[fx] Fix ckpt functions' definitions in forward (#1476)

* [fx] fix defining ckpt functions inside forward

* [fx] Modify activation checkpoint codegen and add ColoGraphModule

* [fx] some modification

* some modifications

* some code modifications
Boyuan Yao 2022-08-22 16:59:54 +08:00 committed by GitHub
parent bb5f5289e0
commit 1f2e547f7a
4 changed files with 217 additions and 39 deletions

[File 1 of 4: activation checkpoint codegen]

@@ -1,12 +1,13 @@
+import colossalai
 import torch
 from typing import List, Callable, Any, Tuple, Dict

 try:
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
     CODEGEN_AVAILABLE = True
 except:
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
     CODEGEN_AVAILABLE = False
@@ -89,7 +90,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
    """
    Generate the checkpoint function definition
    """
-    return f"def checkpoint_{label}({', '.join(free_vars)}):"
+    return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"


 def _gen_ckpt_output(output_vars: List[str]) -> str:
@@ -105,10 +106,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant)
    """
    outputs = ', '.join(output_vars)
    inputs = ', '.join(input_vars)
-    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'


-def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
+def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
     # find the activation checkpoint regions
     ckpt_regions = _find_ckpt_regions(nodes)
     start_idx = [item[0] for item in ckpt_regions]
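For concreteness, here is roughly what the two generators above produce after this change, with an illustrative label and variable names (not output captured from a run):

    _gen_ckpt_fn_def(0, ['x'])
    # -> "def checkpoint_0(self, x):"

    _gen_ckpt_usage(0, False, ['x'], ['out'], False)
    # -> "out = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)"

The only change from the previous behavior is the `self` parameter and the `self.` prefix: the generated helper is now meant to be bound onto the GraphModule (see ColoGraphModule.bind below) rather than being defined as a local function inside forward.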
@@ -133,27 +134,27 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
         if idx in start_idx:
             label = start_idx.index(idx)
             ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
-            body.append(f'{ckpt_fn_def}\n')
+            ckpt_func.append(f'{ckpt_fn_def}\n')
             within_ckpt_region = True

         # NOTE: emit_node does not emit a string with newline. It depends
         # on delete_unused_values to append one
-        emit_node_func(node)
-
-        # add indentation to the emmited node
+        # NOTE: currently we separate body and ckpt_func definition
         if within_ckpt_region:
-            body[-1] = '    ' + body[-1]
-
-        # delete unused values
-        delete_unused_value_func(node)
+            emit_node_func(node, ckpt_func)
+            ckpt_func[-1] = '    ' + ckpt_func[-1]
+            delete_unused_value_func(node, ckpt_func)
+        else:
+            emit_node_func(node, body)
+            delete_unused_value_func(node, body)

         if idx in end_idx:
             # if this is the last node of the ckpt region
             # generate return statement
             label = end_idx.index(idx)
             return_statement = _gen_ckpt_output(output_vars[label])
-            return_statement = f'    {return_statement}\n'
-            body.append(return_statement)
+            return_statement = f'    {return_statement}\n\n'
+            ckpt_func.append(return_statement)

         # we need to check if the checkpoint need to offload the input
         start_node_idx = start_idx[label]
@@ -221,6 +222,9 @@ if CODEGEN_AVAILABLE:
                 globals_[global_name] = obj
                 return global_name

+        # set _custom_builtins here so that we needn't import colossalai in forward
+        _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+
         # Pre-fill the globals table with registered builtins.
         for name, (_, obj) in _custom_builtins.items():
             add_global(name, obj)
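A note on the mechanism: in `torch.fx.graph`, `_custom_builtins` maps a global name to a `_CustomBuiltin` pair of `(import_str, obj)`, and the loop just above copies every entry into the globals of the generated module. That is why the emitted `forward` can now reference `colossalai` without the inline import that is deleted further down. A minimal self-contained sketch of the same registration:

    from torch.fx.graph import _custom_builtins, _CustomBuiltin
    import colossalai

    # every graph compiled afterwards resolves `colossalai` from its pre-filled globals
    _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)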
@@ -287,7 +291,8 @@ if CODEGEN_AVAILABLE:
             map_arg(node.args, lambda n: register_last_uses(n, node))
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))

-        def delete_unused_values(user: Node):
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def delete_unused_values(user: Node, body):
             """
             Delete values after their last use. This ensures that values that are
             not used in the remainder of the code are freed and the memory usage
@@ -305,7 +310,8 @@ if CODEGEN_AVAILABLE:
             else:
                 body.append('\n')

-        def emit_node(node: Node):
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def emit_node(node: Node, body):
             maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
@@ -371,7 +377,8 @@ if CODEGEN_AVAILABLE:
                 raise NotImplementedError(f'node: {node.op} {node.target}')

         # Modified for activation checkpointing
-        emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
+        ckpt_func = []
+        emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
@@ -395,7 +402,8 @@ if CODEGEN_AVAILABLE:
         # in forward function
         # TODO: Remove inline import
         prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
-        prologue = prologue + "\n    import colossalai"
+        prologue = ''.join(ckpt_func) + prologue
+        prologue = prologue

         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))
@@ -444,6 +452,9 @@ else:
                 globals_[global_name] = obj
                 return global_name

+        # set _custom_builtins here so that we needn't import colossalai in forward
+        _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+
         # Pre-fill the globals table with registered builtins.
         for name, (_, obj) in _custom_builtins.items():
             add_global(name, obj)
@@ -484,7 +495,8 @@ else:
             map_arg(node.args, lambda n: register_last_uses(n, node))
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))

-        def delete_unused_values(user: Node):
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def delete_unused_values(user: Node, body):
             """
             Delete values after their last use. This ensures that values that are
             not used in the remainder of the code are freed and the memory usage
@@ -502,7 +514,8 @@ else:
             else:
                 body.append('\n')

-        def emit_node(node: Node):
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def emit_node(node: Node, body):
             maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
@@ -562,7 +575,8 @@ else:
                 raise NotImplementedError(f'node: {node.op} {node.target}')

         # Modified for activation checkpointing
-        emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
+        ckpt_func = []
+        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
@@ -587,6 +601,8 @@ else:
         else:
             wrap_stmts = ''

+        ckpt_func = ''.join(ckpt_func)
+
         # If the original function didn't have self as its first argument, we
         # would have added it.
         if len(orig_args) == 0 or orig_args[0] != 'self':
@@ -600,7 +616,7 @@ else:
         fn_code = f"""
 {wrap_stmts}

+{ckpt_func}
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
-    import colossalai
 {code}"""
         return PythonCode(fn_code, globals_)
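Taken together, the codegen changes move the checkpoint wrappers out of `forward` and above it. A hypothetical example of the generated source for a module with a single checkpointed submodule (names illustrative, shape per `_gen_ckpt_fn_def`, `_gen_ckpt_output`, and `_gen_ckpt_usage` above):

    def checkpoint_0(self, x):
        linear1 = self.linear1(x)
        return linear1

    def forward(self, x):
        linear1 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)
        return linear1

Since `checkpoint_0` now expects `self`, something must attach it to the module instance; that is the job of the new ColoGraphModule below.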

[File 2 of 4: colossalai/fx/graph_module.py, new file]

@@ -0,0 +1,158 @@
+import os
+import warnings
+import torch
+import torch.nn as nn
+from torch.nn.modules.module import _addindent
+from typing import Type, Dict, List, Any, Union, Optional, Set
+from pathlib import Path
+
+try:
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
+    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    COLOGM = True
+except:
+    from torch.fx.graph_module import GraphModule
+    from torch.fx.graph import Graph
+    COLOGM = False
+
+if COLOGM:
+
+    class ColoGraphModule(GraphModule):
+
+        def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            super().__init__(root, graph, class_name)
+
+        def bind(self, ckpt_def, globals):
+            """Bind checkpoint functions to the ColoGraphModule.
+
+            We need to bind our checkpoint functions to the GraphModule so
+            that we can correctly use self.checkpoint_* in the generated forward.
+            """
+            ckpt_code = "\n".join(ckpt_def)
+            globals_copy = globals.copy()
+            _exec_with_source(ckpt_code, globals_copy)
+            func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
+            for func in func_list:
+                tmp_func = globals_copy[func]
+                setattr(self, func, tmp_func.__get__(self, self.__class__))
+                del globals_copy[func]
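The `tmp_func.__get__(self, self.__class__)` call uses Python's descriptor protocol to turn the freshly exec'd function into a method bound to this instance, so the generated `self.checkpoint_*` calls resolve. A minimal standalone sketch with toy names:

    class Owner:
        pass

    def checkpoint_0(self, x):
        # a free function whose first parameter is named `self`
        return x * 2

    obj = Owner()
    # __get__ returns a bound method, so `self` is supplied automatically
    obj.checkpoint_0 = checkpoint_0.__get__(obj, Owner)
    assert obj.checkpoint_0(21) == 42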
+        def recompile(self) -> PythonCode:
+            """
+            Recompile this GraphModule from its ``graph`` attribute. This should be
+            called after editing the contained ``graph``, otherwise the generated
+            code of this ``GraphModule`` will be out of date.
+            """
+            if isinstance(self._graph._codegen, _PyTreeCodeGen):
+                self._in_spec = self._graph._codegen.pytree_info.in_spec
+                self._out_spec = self._graph._codegen.pytree_info.out_spec
+            python_code = self._graph.python_code(root_module='self')
+            self._code = python_code.src
+
+            # split the checkpoint function definitions from the forward code
+            _code_list = self._code.split("\n")
+            _fwd_def = [item for item in _code_list if "def forward" in item][0]
+            _fwd_idx = _code_list.index(_fwd_def)
+            ckpt_def = _code_list[:_fwd_idx]
+            self._code = "\n".join(_code_list[_fwd_idx:])
+
+            self.bind(ckpt_def, python_code.globals)
+
+            cls = type(self)
+            cls.forward = _forward_from_src(self._code, python_code.globals)
+
+            # Determine whether this class explicitly defines a __call__ implementation
+            # to wrap. If it does, save it in order to have wrapped_call invoke it.
+            # If it does not, wrapped_call can use a dynamic call to super() instead.
+            # In most cases, super().__call__ should be torch.nn.Module.__call__.
+            # We do not want to hold a reference to Module.__call__ here; doing so will
+            # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
+            cls_call = cls.__call__ if "__call__" in vars(cls) else None
+
+            if '_wrapped_call' not in vars(cls):
+                cls._wrapped_call = _WrappedCall(cls, cls_call)    # type: ignore[attr-defined]
+
+            def call_wrapped(self, *args, **kwargs):
+                return self._wrapped_call(self, *args, **kwargs)
+
+            cls.__call__ = call_wrapped
+
+            # reset self._code to the original src, otherwise to_folder will be wrong
+            self._code = python_code.src
+            return python_code
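The split relies on the generated source placing every `checkpoint_*` definition before the single `def forward` line, which the codegen above guarantees. A toy illustration of the slicing, on an assumed source string:

    src = ("def checkpoint_0(self, x):\n"
           "    return self.linear1(x)\n"
           "\n"
           "def forward(self, x):\n"
           "    return self.checkpoint_0(x)")
    lines = src.split("\n")
    fwd_idx = lines.index([l for l in lines if "def forward" in l][0])
    ckpt_def = lines[:fwd_idx]              # handed to bind()
    fwd_src = "\n".join(lines[fwd_idx:])    # becomes self._code / cls.forward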
+        def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
+            """Dumps out module to ``folder`` with ``module_name`` so that it can be
+            imported with ``from <folder> import <module_name>``
+
+            Args:
+                folder (Union[str, os.PathLike]): The folder to write the code out to
+                module_name (str): Top-level name to use for the ``Module`` while
+                    writing out the code
+            """
+            folder = Path(folder)
+            Path(folder).mkdir(exist_ok=True)
+            torch.save(self.state_dict(), folder / 'state_dict.pt')
+            tab = " " * 4
+
+            # we add import colossalai here
+            model_str = f"""
+import torch
+from torch.nn import *
+import colossalai
+
+
+class {module_name}(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+"""
+
+            def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
+                safe_reprs = [
+                    nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+                ]
+                if type(module) in safe_reprs:
+                    return f"{module.__repr__()}"
+                else:
+                    return None
+
+            blobified_modules = []
+            for module_name, module in self.named_children():
+                module_str = _gen_model_repr(module_name, module)
+                if module_str is None:
+                    module_file = folder / f'{module_name}.pt'
+                    torch.save(module, module_file)
+                    blobified_modules.append(module_name)
+                    module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+                    module_str = f"torch.load(r'{module_file}') # {module_repr}"
+                model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+            for buffer_name, buffer in self._buffers.items():
+                if buffer is None:
+                    continue
+                model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
+
+            for param_name, param in self._parameters.items():
+                if param is None:
+                    continue
+                model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
+
+            model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
+            model_str += f"{_addindent(self.code, 4)}\n"
+
+            module_file = folder / 'module.py'
+            module_file.write_text(model_str)
+
+            init_file = folder / '__init__.py'
+            init_file.write_text('from .module import *')
+
+            if len(blobified_modules) > 0:
+                warnings.warn("Was not able to save the following children modules as reprs -"
+                              f"saved as pickled files instead: {blobified_modules}")
+
+else:
+
+    class ColoGraphModule(GraphModule):
+
+        def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            super().__init__(root, graph, class_name)

[File 3 of 4: checkpoint solver tests]

@@ -7,6 +7,7 @@ import torchvision.models as tm
 from torch.fx import GraphModule
 import colossalai
 from colossalai.fx import ColoTracer
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.passes.algorithms import chen_greedy
 from colossalai.utils import free_port
@@ -72,7 +73,7 @@ def _run_ckpt_solver(rank):
     for model_cls in MODEL_LIST:
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
-        gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
+        gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
         MetaInfoProp(gm).run(data)
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
@@ -102,7 +103,7 @@ def _run_ckpt_solver_torch11(rank):
     for model_cls in MODEL_LIST:
         m = model_cls(num_classes=5)
         graph = tracer.trace(root=m)
-        gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
+        gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         gm = solver(gm)
@@ -114,10 +115,12 @@ def _run_ckpt_solver_torch11(rank):
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
 def test_ckpt_solver_torch11():
     mp.spawn(_run_ckpt_solver_torch11, nprocs=1)


 if __name__ == '__main__':
-    test_ckpt_solver()
-    test_ckpt_solver_torch11()
+    _run_ckpt_solver(rank=0)
+    # test_ckpt_solver()
+    # test_ckpt_solver_torch11()

[File 4 of 4: activation checkpoint codegen tests]

@@ -9,6 +9,7 @@ from colossalai.fx import ColoTracer
 import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -46,7 +47,7 @@ class MyModule(torch.nn.Module):
         super().__init__()
         self.mlp1 = MLP()
         self.relu = relu()
-        self.linear3 = torch.nn.Linear(4, 4)
+        self.linear2 = torch.nn.Linear(4, 4)

     def forward(self, x):
         y1, y2 = checkpoint(self.mlp1, x)
@@ -56,6 +57,7 @@ class MyModule(torch.nn.Module):
             return F.relu(x, inplace=True)

         y4 = checkpoint(ckpt2, x)
+        y4 = self.linear2(y4)

         return y1 + y2 + y3 + y4
@@ -91,15 +93,15 @@ def _run_act_ckpt_codegen(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)

-    gm = GraphModule(model, graph)
+    gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
-           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
-           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code

     # recompile and verify the outputs are consistent
     fx_out = gm(data)
@@ -145,14 +147,14 @@ def _run_act_ckpt_python_code_torch11(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)

-    gm = GraphModule(model, graph)
+    gm = ColoGraphModule(model, graph)
     gm.recompile()

     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
-           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
-           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code

     # recompile and verify the outputs are consistent
     fx_out = gm(data)
@@ -162,11 +164,10 @@ def _run_act_ckpt_python_code_torch11(rank):
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
 def test_act_ckpt_python_code_torch11():
     mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)


 if __name__ == '__main__':
-    test_act_ckpt_codegen()
-    test_act_ckpt_python_code_torch11()
+    _run_act_ckpt_codegen(rank=0)
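For reference, a hedged sketch of the end-to-end flow these tests exercise, assuming torch >= 1.12 and an already-constructed `model` and `data`:

    from colossalai.fx import ColoTracer
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    from colossalai.fx.graph_module import ColoGraphModule

    graph = ColoTracer().trace(root=model)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph)   # instead of torch.fx.GraphModule
    gm.recompile()                       # binds self.checkpoint_* and compiles forward
    fx_out = gm(data)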