[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
This commit is contained in:
Super Daniel
2022-11-01 10:43:15 +08:00
committed by GitHub
parent 2b859502d5
commit 1e88811c7a
16 changed files with 1025 additions and 119 deletions

View File

@@ -1,14 +1,37 @@
import colossalai
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
from typing import List, Callable, Any, Tuple, Dict, Iterable
import colossalai
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import (
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_args,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
@@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
@@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
Args:
offload_input (bool, optional): whether we need offload input, if offload_input=False,
offload_input (bool, optional): whether we need offload input, if offload_input=False,
we will use self.pack_hook_no_input instead. Defaults to True.
Returns:
str: generated context
"""
@@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
act_ckpt_label = node.meta['activation_checkpoint']
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
elif current_region is not None and not 'activation_checkpoint' in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
@@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
In pofo algorithm, during annotation, we will annotate the offload region with the
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the
@@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
act_offload_label = node.activation_offload
if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
act_offload_label = node.meta['activation_offload']
if current_region == None:
current_region = act_offload_label
@@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
"""Check if the node could end the ckpt region
Args:
node (Node): torch.fx.Node
check_idx (int): the index of checkpoint level for
check_idx (int): the index of checkpoint level for
nested checkpoint
Returns:
bool
"""
if hasattr(node, "activation_checkpoint"):
if isinstance(node.activation_checkpoint, list):
return node.activation_checkpoint[check_idx] == None
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], list):
return node.meta['activation_checkpoint'][check_idx] == None
else:
return False
else:
@@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
def _find_nested_ckpt_regions(nodes, check_idx=0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
@@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
if isinstance(getattr(node, 'activation_checkpoint'), int):
act_ckpt_label = node.activation_checkpoint
if 'activation_checkpoint' in node.meta:
if isinstance(node.meta['activation_checkpoint'], int):
act_ckpt_label = node.meta['activation_checkpoint']
else:
act_ckpt_label = node.activation_checkpoint[check_idx]
act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -287,7 +306,6 @@ def emit_ckpt_func(body,
level=0,
in_ckpt=False):
"""Emit ckpt fuction in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
functions code
@@ -303,8 +321,8 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
if isinstance(node_list[0].activation_checkpoint, int):
label = node_list[0].activation_checkpoint
if isinstance(node_list[0].meta['activation_checkpoint'], int):
label = node_list[0].meta['activation_checkpoint']
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
for node in node_list:
@@ -313,7 +331,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
@@ -322,12 +340,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
# if there is more level to fetch
if level + 1 < len(node_list[0].activation_checkpoint):
if level + 1 < len(node_list[0].meta['activation_checkpoint']):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -354,7 +372,7 @@ def emit_ckpt_func(body,
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
@@ -368,7 +386,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
activation_offload = node_list[0].meta.get('activation_offload', False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
usage = ' ' + usage
@@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
Args:
body: forward code
ckpt_func: checkpoint functions code
@@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if hasattr(node_list[start_node_idx], 'activation_offload'):
activation_offload = node_list[start_node_idx].activation_offload
if 'activation_offload' in node_list[start_node_idx].meta:
activation_offload = node_list[start_node_idx].meta['activation_offload']
else:
activation_offload = False
@@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
if hasattr(user, "activation_checkpoint"):
if user.activation_checkpoint == label:
if 'activation_checkpoint' in user.meta:
if user.meta['activation_checkpoint'] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
@@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
return PythonCode(fn_code, globals_)
@@ -851,10 +865,8 @@ else:
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
@@ -999,7 +1011,7 @@ else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1040,7 +1052,6 @@ else:
# in forward function
fn_code = f"""
{wrap_stmts}
{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}"""