mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.
@@ -1,14 +1,37 @@
-import colossalai
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
 import torch
-from typing import List, Callable, Any, Tuple, Dict, Iterable
+
+import colossalai
 
 try:
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = True
 except:
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import (
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_args,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = False
 
 if CODEGEN_AVAILABLE:
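The guarded import block is plain feature detection: newer PyTorch releases expose `torch.fx.graph.CodeGen`, older ones do not, and the flag selects which emission path the rest of the file defines. A minimal sketch of the same pattern (the print messages are illustrative):

```python
# Feature-detect the fx CodeGen class and pick a code path accordingly.
try:
    from torch.fx.graph import CodeGen    # present in newer PyTorch releases
    CODEGEN_AVAILABLE = True
except ImportError:
    CODEGEN_AVAILABLE = False

if CODEGEN_AVAILABLE:
    print("subclass CodeGen and inject checkpoint wrappers during emission")
else:
    print("fall back to the pre-CodeGen emission path")
```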
@@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
         return (x.device, x.cpu())
     else:
         return x
-
+
 def pack_hook_no_input(self, x):
     if getattr(x, "offload", True):
         return (x.device, x.cpu())
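These generated pack hooks feed PyTorch's saved-tensors-hooks machinery, which moves activations saved for backward off the GPU during forward. A minimal, self-contained sketch of that round trip using the standard `torch.autograd.graph` API (hook names are illustrative):

```python
import torch

# pack_hook ships a saved activation to CPU together with its original device;
# unpack_hook restores it when backward needs the tensor again.
def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, x = packed
    return x.to(device)

a = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    b = a.pow(2).sum()    # the activation saved for pow's backward is offloaded
b.backward()
```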
@@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):
 
 def _gen_save_tensors_hooks_context(offload_input=True) -> str:
     """Generate customized saved_tensors_hooks
-
     Args:
-        offload_input (bool, optional): whether we need offload input, if offload_input=False,
+        offload_input (bool, optional): whether we need offload input, if offload_input=False,
             we will use self.pack_hook_no_input instead. Defaults to True.
-
     Returns:
         str: generated context
     """
@@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
     current_region = None
 
     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            act_ckpt_label = node.meta['activation_checkpoint']
 
             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
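This is the change repeated throughout the commit: checkpoint annotations move from ad-hoc attributes on `fx.Node` objects into `node.meta`, the side table torch.fx reserves for user metadata. A hypothetical round trip:

```python
import torch
from torch.fx import symbolic_trace

# Annotate call_module nodes through node.meta instead of setting
# node.activation_checkpoint directly (the old style in the removed lines).
gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
for node in gm.graph.nodes:
    if node.op == "call_module":
        node.meta['activation_checkpoint'] = 0

labels = [node.meta.get('activation_checkpoint') for node in gm.graph.nodes]
print(labels)    # e.g. [None, 0, 0, None]
```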
@@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
                 current_region = act_ckpt_label
                 start = idx
                 end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
+        elif current_region is not None and not 'activation_checkpoint' in node.meta:
             # used to check the case below
             # node ckpt states = [ckpt, ckpt, non-ckpt]
             end = idx - 1
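Stripped of the Node plumbing, the region scan reduces to grouping runs of identical non-None labels into (start_index, end_index) tuples. A standalone sketch under that reading (the function name and test labels are hypothetical):

```python
from typing import Any, List, Optional, Tuple

def find_regions(labels: List[Optional[Any]]) -> List[Tuple[int, int]]:
    # Consecutive equal, non-None labels collapse into one region.
    regions, current, start = [], None, 0
    for idx, label in enumerate(labels):
        if label is not None and label != current:
            if current is not None:
                regions.append((start, idx - 1))    # a new region starts
            current, start = label, idx
        elif label is None and current is not None:
            regions.append((start, idx - 1))        # region ends on a non-ckpt node
            current = None
    if current is not None:
        regions.append((start, len(labels) - 1))    # region runs to the last node
    return regions

print(find_regions([None, 0, 0, None, 1, 1, 1]))    # [(1, 2), (4, 6)]
```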
@@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):
 
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
-    In pofo algorithm, during annotation, we will annotate the offload region with the
+    In pofo algorithm, during annotation, we will annotate the offload region with the
     list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
     region's index, offload_input is a bool type indicates whether we need to offload
     the input, offload_bar is a bool type indicates whether we need to offload all the
@@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
     current_region = None
 
     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
-            act_offload_label = node.activation_offload
+        if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
+            act_offload_label = node.meta['activation_offload']
 
             if current_region == None:
                 current_region = act_offload_label
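A hypothetical label in the documented `[idx, offload_input, offload_bar]` form, together with the `Iterable` check the scan keys on (plain bool annotations from non-pofo paths fail it and are skipped):

```python
from typing import Iterable

# Offload region 2: offload its input, don't offload everything inside it.
meta = {'activation_offload': [2, True, False]}

label = meta['activation_offload']
if isinstance(label, Iterable):
    region_idx, offload_input, offload_bar = label
    print(region_idx, offload_input, offload_bar)    # 2 True False
```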
@@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant
 
 def _end_of_ckpt(node: Node, check_idx: int) -> bool:
     """Check if the node could end the ckpt region
-
     Args:
         node (Node): torch.fx.Node
-        check_idx (int): the index of checkpoint level for
+        check_idx (int): the index of checkpoint level for
             nested checkpoint
-
     Returns:
         bool
     """
-    if hasattr(node, "activation_checkpoint"):
-        if isinstance(node.activation_checkpoint, list):
-            return node.activation_checkpoint[check_idx] == None
+    if 'activation_checkpoint' in node.meta:
+        if isinstance(node.meta['activation_checkpoint'], list):
+            return node.meta['activation_checkpoint'][check_idx] == None
         else:
             return False
     else:
@@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
 
 def _find_nested_ckpt_regions(nodes, check_idx=0):
     """
-    Find the nested checkpoint regions given a list of consecutive nodes. The outputs
+    Find the nested checkpoint regions given a list of consecutive nodes. The outputs
     will be list of tuples, each tuple is in the form of (start_index, end_index).
     """
     ckpt_regions = []
@@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
     current_region = None
 
     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            if isinstance(getattr(node, 'activation_checkpoint'), int):
-                act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            if isinstance(node.meta['activation_checkpoint'], int):
+                act_ckpt_label = node.meta['activation_checkpoint']
             else:
-                act_ckpt_label = node.activation_checkpoint[check_idx]
+                act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
 
             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
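For nested regions the label is a list with one entry per nesting level, and `check_idx` selects the level being grouped. An illustrative sketch of what the scan sees:

```python
# Four nodes, all in outer region 0, split across two inner regions at level 1.
labels = [[0, 0], [0, 0], [0, 1], [0, 1]]

check_idx = 1
inner = [label[check_idx] for label in labels]
print(inner)    # [0, 0, 1, 1] -> two level-1 regions: (0, 1) and (2, 3)
```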
@@ -287,7 +306,6 @@ def emit_ckpt_func(body,
                    level=0,
                    in_ckpt=False):
     """Emit ckpt fuction in nested way
-
     Args:
         body: forward code, in recursive calls, this part will be checkpoint
             functions code
@@ -303,8 +321,8 @@ def emit_ckpt_func(body,
     inputs, outputs = _find_input_and_output_nodes(node_list)
 
     # if the current checkpoint function use int as label, using old generation method
-    if isinstance(node_list[0].activation_checkpoint, int):
-        label = node_list[0].activation_checkpoint
+    if isinstance(node_list[0].meta['activation_checkpoint'], int):
+        label = node_list[0].meta['activation_checkpoint']
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')
         for node in node_list:
@@ -313,7 +331,7 @@ def emit_ckpt_func(body,
             delete_unused_value_func(node, ckpt_func)
 
         ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-        activation_offload = getattr(node_list[0], "activation_offload", False)
+        activation_offload = node_list[0].meta.get('activation_offload', False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
         usage += "\n"
         body.append(usage)
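The emitted pair, a `checkpoint_N` function plus a call through `torch.utils.checkpoint`, looks roughly like the sketch below; the body and signature are illustrative stand-ins, not the generated text itself:

```python
import torch
import torch.utils.checkpoint

# Shape of what _gen_ckpt_fn_def / _gen_ckpt_usage emit for label 0.
def checkpoint_0(x):
    return torch.relu(x) * 2

x = torch.randn(3, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(checkpoint_0, x, use_reentrant=False)
out.sum().backward()    # checkpoint_0 is re-run during backward instead of
                        # keeping its intermediate activations alive
```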
@@ -322,12 +340,12 @@ def emit_ckpt_func(body,
     else:
         # label given by each layer, e.g. if you are currently at level [0, 1, 1]
         # the label will be '0_1_1'
-        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
+        label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')
 
         # if there is more level to fetch
-        if level + 1 < len(node_list[0].activation_checkpoint):
+        if level + 1 < len(node_list[0].meta['activation_checkpoint']):
             ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
             start_idx = [item[0] for item in ckpt_regions]
             end_idx = [item[1] for item in ckpt_regions]
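The label strings are just the nested label truncated at each level and joined with underscores:

```python
# Reproducing the '0_1_1' example from the comment above.
label = [0, 1, 1]
for level in range(len(label)):
    print("_".join(str(idx) for idx in label[:level + 1]))
# 0
# 0_1
# 0_1_1
```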
@@ -354,7 +372,7 @@ def emit_ckpt_func(body,
 
             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
             ckpt_func += ckpt_func_buffer
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage
@@ -368,7 +386,7 @@ def emit_ckpt_func(body,
                 delete_unused_value_func(node, ckpt_func)
 
             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage
@@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.
-
     Args:
         body: forward code
         ckpt_func: checkpoint functions code
@@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 
             # we need to check if the checkpoint need to offload the input
             start_node_idx = start_idx[label]
-            if hasattr(node_list[start_node_idx], 'activation_offload'):
-                activation_offload = node_list[start_node_idx].activation_offload
+            if 'activation_offload' in node_list[start_node_idx].meta:
+                activation_offload = node_list[start_node_idx].meta['activation_offload']
             else:
                 activation_offload = False
 
@@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
                 if input_node.op != "placeholder":
                     non_leaf_input = 1
                 for user in input_node.users:
-                    if hasattr(user, "activation_checkpoint"):
-                        if user.activation_checkpoint == label:
+                    if 'activation_checkpoint' in user.meta:
+                        if user.meta['activation_checkpoint'] == label:
                             if user.op == "call_module":
                                 if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
                                     use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
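The rule encoded above, in isolation: an in-place submodule inside the checkpointed region rules out the reentrant checkpoint implementation, so `use_reentrant` flips to False. A minimal sketch:

```python
import torch

# An in-place module mutates its input, which the reentrant checkpoint
# implementation cannot replay safely; the non-reentrant one is used instead.
submodule = torch.nn.ReLU(inplace=True)
use_reentrant = not getattr(submodule, 'inplace', False)
print(use_reentrant)    # False
```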
@@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:
 
         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.
-
             We call this for names that reference objects external to the
             Graph, like functions or types.
-
             Returns: the global name that should be used to reference 'obj' in generated source.
             """
             if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
@@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:
 
             # if any node has a list of labels for activation_checkpoint, we
             # will use nested type of activation checkpoint codegen
-            if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
+            if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
                 emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
             else:
                 emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
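The dispatch predicate in isolation: flat labels are ints and fail the `Iterable` test, nested labels are lists and pass it.

```python
from typing import Iterable

print(isinstance(3, Iterable))         # False -> emit_code_with_activation_checkpoint
print(isinstance([0, 1], Iterable))    # True  -> emit_code_with_nested_activation_checkpoint
```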
@@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
         code = '\n'.join('    ' + line for line in code.split('\n'))
         fn_code = f"""
 {wrap_stmts}
-
 {prologue}
 {code}"""
         return PythonCode(fn_code, globals_)
@@ -851,10 +865,8 @@ else:
 
         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.
-
             We call this for names that reference objects external to the
             Graph, like functions or types.
-
             Returns: the global name that should be used to reference 'obj' in generated source.
             """
             if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
@@ -999,7 +1011,7 @@ else:
 
         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
+        if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
             emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
         else:
             emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1040,7 +1052,6 @@ else:
     # in forward function
     fn_code = f"""
 {wrap_stmts}
-
 {ckpt_func}
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 {code}"""
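The final assembly concatenates the checkpoint function definitions above a generated `forward()` whose body calls them. A schematic of that f-string composition with stand-in fragments (all names and bodies below are illustrative):

```python
# Stand-ins for the generated fragments interpolated into fn_code.
wrap_stmts = "import torch"
ckpt_func = "def checkpoint_0(x):\n    return torch.relu(x)\n"
code = "    return torch.utils.checkpoint.checkpoint(checkpoint_0, x)"

fn_code = f"""
{wrap_stmts}

{ckpt_func}
def forward(x):
{code}"""
print(fn_code)    # a complete, importable module source
```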