[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
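All of the hunks below are mechanical restyles of the kind the updated hooks apply; the pattern matches black's conventions: long call and signature wrappings are exploded to one argument per line with a trailing comma, redundant parentheses are dropped, and string literals are normalized to double quotes. A minimal, self-contained sketch of the before/after pattern (the stand-in values are illustrative; the real code pulls them from FX graph nodes):

    # Stand-in values so the snippet runs on its own.
    shape_str = "[2, 3]"
    input_name = "x"

    # Before (yapf-era style): continuation aligned to the opening parenthesis.
    # tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_name,
    #                                                                       input_name)

    # After (black style): one argument per line, magic trailing comma.
    tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
        shape_str,
        input_name,
        input_name,
    )
    print(tensor_str)  # torch.empty([2, 3], dtype=x.dtype, device=x.device),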
@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
 
 if AUTOCHUNK_AVAILABLE:
-    from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )
     from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
     for i in range(len(chunk_output)):
         shape_str = str(list(get_node_shape(chunk_output[i])))
         if get_node_name(chunk_output[i]) in ["split", "unbind"]:
-            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
-                                                                                  input_node.name)
-            tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
+                shape_str,
+                input_node.name,
+                input_node.name,
+            )
+            tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
             tensor_str = "[" + tensor_str[:-2] + "]"
             context += "%s = %s; " % (chunk_output[i].name, tensor_str)
         else:
-            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
-                                                                                     input_node.name, input_node.name)
+            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
+                chunk_output[i].name,
+                shape_str,
+                input_node.name,
+                input_node.name,
+            )
 
     out_shape = get_node_shape(chunk_output[0])
     chunk_shape = out_shape[chunk_output_dim[0]]
@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
     return context
 
 
-def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
-                  chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+def _gen_loop_end(
+    chunk_inputs: List[Node],
+    chunk_non_compute_inputs: List[Node],
+    node_list: List[Node],
+    chunk_outputs_idx: int,
+    chunk_outputs_non_tensor: List[Node],
+    search_chunk: SearchChunk,
+) -> str:
     """
     Generate chunk loop end
@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
     chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
     if get_node_shape(meta_node)[chunk_dim] != 1:
         source_node = meta_node.args[0].args[0]
-        if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
-                or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+        if (
+            source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+            or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
+        ):
             chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
             body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
     return body
@@ -203,11 +229,12 @@ def _add_node_slice(
         # outputs node
         else:
             if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
-                chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
-                                                   get_node_shape(chunk_node))
+                chunk_slice = _gen_chunk_slice_dim(
+                    chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
+                )
                 if get_node_name(chunk_node) in ["split", "unbind"]:
                     split_chunk_slice = ""
-                    for i in range(len(chunk_node.meta['tensor_meta'])):
+                    for i in range(len(chunk_node.meta["tensor_meta"])):
                         split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
                     split_chunk_slice = split_chunk_slice[:-2]
                     body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
@@ -216,13 +243,15 @@ def _add_node_slice(
     return body
 
 
-def emit_code_with_chunk(body: List[str],
-                         nodes: Iterable[Node],
-                         emit_node_func: Callable,
-                         delete_unused_value_func: Callable,
-                         search_chunk: SearchChunk,
-                         chunk_infos: List,
-                         eval_mem: bool = False):
+def emit_code_with_chunk(
+    body: List[str],
+    nodes: Iterable[Node],
+    emit_node_func: Callable,
+    delete_unused_value_func: Callable,
+    search_chunk: SearchChunk,
+    chunk_infos: List,
+    eval_mem: bool = False,
+):
     """
     Emit code with chunk according to chunk_infos.
@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
     chunk_ends = [i["region"][1] for i in chunk_infos]
 
     # chunk inputs
-    chunk_inputs = [i["inputs"] for i in chunk_infos]    # input with chunk
-    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]    # input without chunk
-    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]    # input chunk dim
+    chunk_inputs = [i["inputs"] for i in chunk_infos]  # input with chunk
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]  # input without chunk
+    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]  # input chunk dim
     chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
 
     # chunk outputs
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
                     chunk_outputs[region_idx],
                     chunk_outputs_dim[region_idx],
                     chunk_infos[region_idx]["chunk_size"],
-                ))
+                )
+            )
 
         if within_chunk_region:
             emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
             if eval_mem:
                 body.append(
                     " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
-                    % (node.name))
+                    % (node.name)
+                )
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
                 if eval_mem:
                     body.append(
                         "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
-                        % (node.name))
+                        % (node.name)
+                    )
 
         # generate chunk region end
        if node_idx in chunk_ends:
             body.append(
-                _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
-                              chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+                _gen_loop_end(
+                    chunk_inputs[region_idx],
+                    chunk_inputs_non_chunk[region_idx],
+                    node_list,
+                    chunk_ends[region_idx],
+                    chunk_outputs_non_tensor[region_idx],
+                    search_chunk,
+                )
+            )
             within_chunk_region = False
 
         node_idx += 1
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
 if AUTOCHUNK_AVAILABLE:
 
     class AutoChunkCodeGen(CodeGen):
-
-        def __init__(self,
-                     meta_graph,
-                     max_memory: int = None,
-                     print_mem: bool = False,
-                     print_progress: bool = False,
-                     eval_mem: bool = False) -> None:
+        def __init__(
+            self,
+            meta_graph,
+            max_memory: int = None,
+            print_mem: bool = False,
+            print_progress: bool = False,
+            eval_mem: bool = False,
+        ) -> None:
             super().__init__()
             self.eval_mem = eval_mem
             # find the chunk regions
@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
 
            Returns: the global name that should be used to reference 'obj' in generated source.
            """
-            if (_is_from_torch(obj) and obj != torch.device):    # to support registering torch.device
+            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
            return add_global(typename, o)
 
        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
            def _get_repr(arg):
                # Handle NamedTuples (if it has `_fields`) via add_global.
                if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
 
        # NOTE: we add a variable to distinguish body and ckpt_func
        def emit_node(node: Node, body):
-            maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+            maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
            if node.op == "placeholder":
                assert isinstance(node.target, str)
-                maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+                maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
                free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
                raw_name = node.target.replace("*", "")
                if raw_name != repr(node):
@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
                assert isinstance(node.target, str)
                body.append(
                    f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
-                    f"({_format_args(node.args[1:], node.kwargs)})")
+                    f"({_format_args(node.args[1:], node.kwargs)})"
+                )
                return
            elif node.op == "call_function":
                assert callable(node.target)
                # pretty print operators
-                if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+                if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
-                    body.append(f"{repr(node)}{maybe_type_annotation} = "
-                                f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = "
+                        f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+                    )
                    return
 
                # pretty print inplace operators; required for jit.script to work properly
                # not currently supported in normal FX graphs, but generated by torchdynamo
-                if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
-                    body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
-                                f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+                if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+                    body.append(
+                        f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+                        f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+                    )
                    return
 
                qualified_name = _get_qualified_name(node.target)
                global_name = add_global(qualified_name, node.target)
                # special case for getattr: node.args could be 2-argument or 3-argument
                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
-                        and node.args[1].isidentifier() and len(node.args) == 2):
+                if (
+                    global_name == "getattr"
+                    and isinstance(node.args, tuple)
+                    and isinstance(node.args[1], str)
+                    and node.args[1].isidentifier()
+                    and len(node.args) == 2
+                ):
                    body.append(
-                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+                    )
                    return
                body.append(
-                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+                )
                if node.meta.get("is_wrapped", False):
                    wrapped_fns.setdefault(global_name)
                return
            elif node.op == "call_module":
                assert isinstance(node.target, str)
-                body.append(f"{repr(node)}{maybe_type_annotation} = "
-                            f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = "
+                    f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+                )
                return
            elif node.op == "get_attr":
                assert isinstance(node.target, str)
@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
 
        # if any node has a list of labels for activation_checkpoint, we
        # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
-                             self.eval_mem)
+        emit_code_with_chunk(
+            body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
+        )
 
        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
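Because every hunk above changes formatting only, the commit should be a semantic no-op. A hedged sketch of how to check that locally (it assumes black is the formatter the updated pre-commit hooks run and that the `black` package is installed; the repository's actual hook pins and options may differ):

    import ast

    import black  # assumption: the formatter behind the updated hooks

    # An old-style wrapping reduced from this diff to a runnable stub.
    old = (
        "def _gen_loop_end(chunk_inputs, chunk_non_compute_inputs, node_list,\n"
        "                  chunk_outputs_idx, chunk_outputs_non_tensor, search_chunk):\n"
        "    return ''\n"
    )

    # Reformat, then verify the AST (and hence the semantics) is unchanged.
    new = black.format_str(old, mode=black.Mode())
    assert ast.dump(ast.parse(old)) == ast.dump(ast.parse(new))
    print(new)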