Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-13 13:11:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
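Most of the diff below is mechanical: the updated pre-commit hooks run a black-style formatter over the repository, so over-long call sites, signatures, and conditions are exploded onto one argument per line with trailing commas, string literals are normalized to double quotes, and trailing comments get exactly two leading spaces. A minimal sketch of the pattern, using hypothetical names rather than code taken from this diff:

    def estimate_mem(node_list, chunk_infos, max_memory, mode="inference"):
        # Hypothetical helper, standing in for the real estimators touched in this diff.
        return 0

    # Before: yapf-style wrapping, continuation aligned under the opening parenthesis, single quotes.
    mem_peak = estimate_mem(node_list, chunk_infos, max_memory,
                            mode='inference')

    # After: black-style formatting, one argument per line with a trailing comma, double quotes.
    mem_peak = estimate_mem(
        node_list,
        chunk_infos,
        max_memory,
        mode="inference",
    )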
@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
 AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()

 if AUTOCHUNK_AVAILABLE:
-    from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )

 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
     for i in range(len(chunk_output)):
         shape_str = str(list(get_node_shape(chunk_output[i])))
         if get_node_name(chunk_output[i]) in ["split", "unbind"]:
-            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
-                                                                                  input_node.name)
-            tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+            tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
+                shape_str,
+                input_node.name,
+                input_node.name,
+            )
+            tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
             tensor_str = "[" + tensor_str[:-2] + "]"
             context += "%s = %s; " % (chunk_output[i].name, tensor_str)
         else:
-            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
-                                                                                     input_node.name, input_node.name)
+            context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
+                chunk_output[i].name,
+                shape_str,
+                input_node.name,
+                input_node.name,
+            )

     out_shape = get_node_shape(chunk_output[0])
     chunk_shape = out_shape[chunk_output_dim[0]]
@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
     return context


-def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
-                  chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+def _gen_loop_end(
+    chunk_inputs: List[Node],
+    chunk_non_compute_inputs: List[Node],
+    node_list: List[Node],
+    chunk_outputs_idx: int,
+    chunk_outputs_non_tensor: List[Node],
+    search_chunk: SearchChunk,
+) -> str:
     """
     Generate chunk loop end

@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
         chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
         if get_node_shape(meta_node)[chunk_dim] != 1:
             source_node = meta_node.args[0].args[0]
-            if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
-                    or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+            if (
+                source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+                or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
+            ):
                 chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
                 body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
     return body
@@ -203,11 +229,12 @@ def _add_node_slice(
     # outputs node
     else:
         if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
-            chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
-                                               get_node_shape(chunk_node))
+            chunk_slice = _gen_chunk_slice_dim(
+                chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
+            )
             if get_node_name(chunk_node) in ["split", "unbind"]:
                 split_chunk_slice = ""
-                for i in range(len(chunk_node.meta['tensor_meta'])):
+                for i in range(len(chunk_node.meta["tensor_meta"])):
                     split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
                 split_chunk_slice = split_chunk_slice[:-2]
                 body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
@@ -216,13 +243,15 @@ def _add_node_slice(
     return body


-def emit_code_with_chunk(body: List[str],
-                         nodes: Iterable[Node],
-                         emit_node_func: Callable,
-                         delete_unused_value_func: Callable,
-                         search_chunk: SearchChunk,
-                         chunk_infos: List,
-                         eval_mem: bool = False):
+def emit_code_with_chunk(
+    body: List[str],
+    nodes: Iterable[Node],
+    emit_node_func: Callable,
+    delete_unused_value_func: Callable,
+    search_chunk: SearchChunk,
+    chunk_infos: List,
+    eval_mem: bool = False,
+):
     """
     Emit code with chunk according to chunk_infos.

@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
     chunk_ends = [i["region"][1] for i in chunk_infos]

     # chunk inputs
-    chunk_inputs = [i["inputs"] for i in chunk_infos]    # input with chunk
-    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]    # input without chunk
-    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]    # input chunk dim
+    chunk_inputs = [i["inputs"] for i in chunk_infos]  # input with chunk
+    chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]  # input without chunk
+    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]  # input chunk dim
     chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]

     # chunk outputs
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
                     chunk_outputs[region_idx],
                     chunk_outputs_dim[region_idx],
                     chunk_infos[region_idx]["chunk_size"],
-                ))
+                )
+            )

         if within_chunk_region:
             emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
             if eval_mem:
                 body.append(
                     " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
-                    % (node.name))
+                    % (node.name)
+                )
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
                 if eval_mem:
                     body.append(
                         "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
-                        % (node.name))
+                        % (node.name)
+                    )

         # generate chunk region end
         if node_idx in chunk_ends:
             body.append(
-                _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
-                              chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+                _gen_loop_end(
+                    chunk_inputs[region_idx],
+                    chunk_inputs_non_chunk[region_idx],
+                    node_list,
+                    chunk_ends[region_idx],
+                    chunk_outputs_non_tensor[region_idx],
+                    search_chunk,
+                )
+            )
             within_chunk_region = False

         node_idx += 1
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
 if AUTOCHUNK_AVAILABLE:

     class AutoChunkCodeGen(CodeGen):
-
-        def __init__(self,
-                     meta_graph,
-                     max_memory: int = None,
-                     print_mem: bool = False,
-                     print_progress: bool = False,
-                     eval_mem: bool = False) -> None:
+        def __init__(
+            self,
+            meta_graph,
+            max_memory: int = None,
+            print_mem: bool = False,
+            print_progress: bool = False,
+            eval_mem: bool = False,
+        ) -> None:
             super().__init__()
             self.eval_mem = eval_mem
             # find the chunk regions
@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:

             Returns: the global name that should be used to reference 'obj' in generated source.
             """
-            if (_is_from_torch(obj) and obj != torch.device):  # to support registering torch.device
+            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                 # HACK: workaround for how torch custom ops are registered. We
                 # can't import them like normal modules so they must retain their
                 # fully qualified name.
@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
             return add_global(typename, o)

         def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
             def _get_repr(arg):
                 # Handle NamedTuples (if it has `_fields`) via add_global.
                 if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:

         # NOTE: we add a variable to distinguish body and ckpt_func
         def emit_node(node: Node, body):
-            maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+            maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
             if node.op == "placeholder":
                 assert isinstance(node.target, str)
-                maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+                maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
                 free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
                 raw_name = node.target.replace("*", "")
                 if raw_name != repr(node):
@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
                 assert isinstance(node.target, str)
                 body.append(
                     f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
-                    f"({_format_args(node.args[1:], node.kwargs)})")
+                    f"({_format_args(node.args[1:], node.kwargs)})"
+                )
                 return
             elif node.op == "call_function":
                 assert callable(node.target)
                 # pretty print operators
-                if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+                if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
                     assert isinstance(node.args, tuple)
-                    body.append(f"{repr(node)}{maybe_type_annotation} = "
-                                f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = "
+                        f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+                    )
                     return

                 # pretty print inplace operators; required for jit.script to work properly
                 # not currently supported in normal FX graphs, but generated by torchdynamo
-                if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
-                    body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
-                                f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+                if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+                    body.append(
+                        f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+                        f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+                    )
                     return

                 qualified_name = _get_qualified_name(node.target)
                 global_name = add_global(qualified_name, node.target)
                 # special case for getattr: node.args could be 2-argument or 3-argument
                 # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
-                        and node.args[1].isidentifier() and len(node.args) == 2):
+                if (
+                    global_name == "getattr"
+                    and isinstance(node.args, tuple)
+                    and isinstance(node.args[1], str)
+                    and node.args[1].isidentifier()
+                    and len(node.args) == 2
+                ):
                     body.append(
-                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+                    )
                     return
                 body.append(
-                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+                )
                 if node.meta.get("is_wrapped", False):
                     wrapped_fns.setdefault(global_name)
                 return
             elif node.op == "call_module":
                 assert isinstance(node.target, str)
-                body.append(f"{repr(node)}{maybe_type_annotation} = "
-                            f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = "
+                    f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+                )
                 return
             elif node.op == "get_attr":
                 assert isinstance(node.target, str)
@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:

         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
-                             self.eval_mem)
+        emit_code_with_chunk(
+            body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
+        )

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
@@ -1,11 +1,8 @@
 import copy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Dict, List

 import torch
 from torch.fx.node import Node

 from colossalai.fx.profiler import activation_size, parameter_size
-
 from .utils import NodeMgr, get_node_shape, is_non_memory_node
@@ -62,12 +59,9 @@ class EstimateMemory(object):
             delete_node_dict[node] = max(node_user_idx)
         return delete_node_dict

-    def _remove_deactive_node(self,
-                              user_idx: int,
-                              user: Node,
-                              active_nodes: List,
-                              delete_node_dict: List,
-                              kept_nodes: List = None) -> None:
+    def _remove_deactive_node(
+        self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
+    ) -> None:
         """
         remove deactivate nodes from active nodes
         """
@@ -169,7 +163,7 @@ class EstimateMemory(object):
         use_chunk = True if chunk_infos is not None else False
         chunk_within = False
         chunk_region_idx = None
-        chunk_ratio = 1    # use it to estimate chunk mem
+        chunk_ratio = 1  # use it to estimate chunk mem
         chunk_inputs_all = []

         if use_chunk:
@@ -184,7 +178,6 @@ class EstimateMemory(object):
             chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]

         for idx, node in enumerate(node_mgr.get_node_list()):
-
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
@@ -193,8 +186,9 @@ class EstimateMemory(object):

             # determine chunk ratio for current node
             if chunk_within:
-                chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
-                                                    chunk_sizes[chunk_region_idx])
+                chunk_ratio = self._get_chunk_ratio(
+                    node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
+                )

             # add current node as active node
             self._add_active_node(node, active_nodes, chunk_ratio)
@@ -222,7 +216,7 @@ class EstimateMemory(object):

             # if node in chunk end nodes, restore chunk settings
             if use_chunk and idx in chunk_ends:
-                self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)    # dont provide kept nodes now
+                self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)  # dont provide kept nodes now
                 chunk_within = False
                 chunk_ratio = 1
                 chunk_region_idx = None
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder


 class SearchChunk(object):
@@ -121,8 +121,10 @@ class SearchChunk(object):
         # check if peak node already in chunk info
         if chunk_regions is not None:
             for i in chunk_regions:
-                if i["region"][0] < peak_region[0] <= i["region"][1] or \
-                        i["region"][0] < peak_region[1] <= i["region"][1]:
+                if (
+                    i["region"][0] < peak_region[0] <= i["region"][1]
+                    or i["region"][0] < peak_region[1] <= i["region"][1]
+                ):
                     return None

         active_node_num = [len(i) for i in active_node]
@@ -146,9 +148,9 @@ class SearchChunk(object):
             region = i["region"]
             if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                 return None
-            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+            elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
                 chunk_region_start = region[1] + 1
-            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+            elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
                 chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

@@ -171,7 +173,7 @@ class SearchChunk(object):
             chunk_infos: possible regions found
         """
         start_traces = input_trace[start_idx]
-        if len(start_traces) > 1:    # TODO need to be removed
+        if len(start_traces) > 1:  # TODO need to be removed
             return []
         end_trace = output_trace[end_idx]
         end_node = self.node_mgr.get_node_by_idx(end_idx)
@@ -180,8 +182,9 @@ class SearchChunk(object):
         for end_dim, _ in enumerate(end_trace["indice"]):
             for start_node, start_trace in start_traces.items():
                 for start_dim, _ in enumerate(start_trace["indice"]):
-                    if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
-                                                                  end_idx):
+                    if not self.trace_flow.check_region_start_end(
+                        start_node, start_dim, start_idx, end_node, end_dim, end_idx
+                    ):
                         continue
                     # flow search
                     chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
@@ -203,7 +206,7 @@ class SearchChunk(object):
         """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
-        input_trace = []    # trace of a node's input nodes
+        input_trace = []  # trace of a node's input nodes
         for _, n in enumerate(self.node_mgr.get_node_list()):
             cur_trace = {}
             for arg in n.args:
@@ -215,7 +218,8 @@ class SearchChunk(object):
             for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
                 # skip non compute nodes
                 if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
-                        self.node_mgr.get_node_by_idx(end_idx)):
+                    self.node_mgr.get_node_by_idx(end_idx)
+                ):
                     continue
                 # select free dim
                 chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
@@ -279,15 +283,18 @@ class SearchChunk(object):
             chunk_infos.append(chunk_info)

             mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
-                self.node_mgr.get_node_list(), chunk_infos)
+                self.node_mgr.get_node_list(), chunk_infos
+            )

             if self.print_progress:
-                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
-                                  (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+                get_logger().info(
+                    "AutoChunk find chunk region %d = (%d, %d)"
+                    % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
+                )

         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
-                                                              chunk_infos,
-                                                              print_mem=True)
+            self.estimate_memory.estimate_chunk_inference_mem(
+                self.node_mgr.get_node_list(), chunk_infos, print_mem=True
+            )
         return chunk_infos
@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node


 class SelectChunk(object):
-
     def __init__(
         self,
         trace_indice: TraceIndice,
@@ -20,7 +19,7 @@ class SelectChunk(object):
         self.node_mgr = node_mgr
         if max_memory is not None:
             self.stratge = "fit_memory"
-            self.max_memory = max_memory    # MB
+            self.max_memory = max_memory  # MB
         else:
             self.stratge = "min_memory"

@@ -57,16 +56,18 @@ class SelectChunk(object):
             cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
-            cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
+            cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
             if cur_chunk_region_max_peak < self.max_memory:
-                regions_dict.append({
-                    "chunk_info": region,
-                    "chunk_max_mem": cur_chunk_region_max_peak,
-                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
-                    "reorder_chunk_info": cur_region,
-                    "reorder_node_list": cur_node_list,
-                })
+                regions_dict.append(
+                    {
+                        "chunk_info": region,
+                        "chunk_max_mem": cur_chunk_region_max_peak,
+                        "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                        "reorder_chunk_info": cur_region,
+                        "reorder_node_list": cur_node_list,
+                    }
+                )
         # no region found
         if len(regions_dict) == 0:
             raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -90,13 +91,15 @@ class SelectChunk(object):
             chunk_size *= 2
             reorder_chunk_info["chunk_size"] = chunk_size
             cur_chunk_infos = chunk_infos + [reorder_chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+                chunk_region_dict["reorder_node_list"], cur_chunk_infos
+            )[0]
+            cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
         # search exact size
         chunk_info = chunk_region_dict["chunk_info"]
-        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
-                                                                  chunk_infos)
+        chunk_info["chunk_size"] = self._chunk_size_binary_search(
+            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
+        )
         return chunk_info

     def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -109,9 +112,10 @@ class SelectChunk(object):
             mid = int((left + right) / 2 + 0.5)
             chunk_info["chunk_size"] = mid
             cur_chunk_infos = chunk_infos + [chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
-                                                                             cur_chunk_infos)[0]
-            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+                chunk_region_dict["reorder_node_list"], cur_chunk_infos
+            )[0]
+            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
             if cur_chunk_max_mem >= self.max_memory:
                 right = mid - gap
             else:
@@ -139,8 +143,10 @@ class SelectChunk(object):
             return None

         # get max possible chunk region
-        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
-                                     max([i["region"][1] for i in possible_chunk_regions]))
+        max_possible_chunk_region = (
+            min([i["region"][0] for i in possible_chunk_regions]),
+            max([i["region"][1] for i in possible_chunk_regions]),
+        )

         # get mem for chunk region
         regions_dict_list = []
@@ -149,15 +155,17 @@ class SelectChunk(object):
             cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
-            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
-            regions_dict_list.append({
-                "chunk_info": region,
-                "chunk_max_mem": cur_chunk_region_max_peak,
-                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
-                "reorder_chunk_info": cur_region,
-                "reorder_node_list": cur_node_list,
-            })
+            regions_dict_list.append(
+                {
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                    "reorder_chunk_info": cur_region,
+                    "reorder_node_list": cur_node_list,
+                }
+            )

         # select the min mem
         chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
@@ -175,7 +183,9 @@ class SelectChunk(object):
             return False
         for i in chunk_infos:
             region = i["region"]
-            if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
-                    (chunk_region_start < region[0] and chunk_region_end < region[0])):
+            if not (
+                (chunk_region_start > region[1] and chunk_region_end > region[1])
+                or (chunk_region_start < region[0] and chunk_region_end < region[0])
+            ):
                 return False
         return True
@@ -16,7 +16,6 @@ from .utils import (


 class TraceFlow(object):
-
     def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
         self.trace_indice = trace_indice
         self.node_mgr = node_mgr
@@ -151,7 +150,7 @@ class TraceFlow(object):
         return True

     def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]    # start from the last node
+        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]  # start from the last node
         all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}

         while len(cur_node_list) > 0:
@@ -266,7 +265,7 @@ class TraceFlow(object):
         maybe_prepose_nodes.sort(
             key=lambda x: self.node_mgr.find_node_idx(x),
             reverse=True,
-        )    # from last node to first node
+        )  # from last node to first node
         prepose_nodes = []
         # set every node as root, search its args, if all legal, turn root and args as prepose nodes
         while len(maybe_prepose_nodes) > 0:
@@ -328,7 +327,8 @@ class TraceFlow(object):

     def flow_search(self, start_idx, start_dim, end_idx, end_dim):
         inputs, outputs = find_chunk_compute_input_and_output_nodes(
-            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+        )

         # get every node's chunk dim and fix dim
         all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
@@ -371,8 +371,9 @@ class TraceFlow(object):

         return chunk_info

-    def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
-                               chunk_info: Dict):
+    def _get_other_output_info(
+        self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
+    ):
         start_node = self.node_mgr.get_node_by_idx(start_idx)
         # loop all outputs
         for output in outputs:
@@ -384,8 +385,8 @@ class TraceFlow(object):
             # skip non tensor
             if get_node_shape(output) is None:
                 # log shape tensor
-                if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
-                    chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+                if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
+                    chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
                 continue
             # loop every dim of outputs, try to find a legal one
             for output_dim in range(len(get_node_shape(output))):
@@ -421,7 +422,8 @@ class TraceFlow(object):
                 for k, v in new_all_node_info.items():
                     if k in chunk_info["node_chunk_dim"]:
                         chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
-                            set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+                            set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
+                        )
                     else:
                         chunk_info["node_chunk_dim"][k] = v
                 chunk_info["outputs"].append(output)
@@ -443,8 +445,11 @@ class TraceFlow(object):
             if node.args[0] in chunk_info["inputs_non_chunk"]:
                 continue
             reshape_args = flat_list(node.args[1:])
-            if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
-                    reshape_args[0].meta['fwd_out']) > 1:
+            if (
+                len(reshape_args) == 1
+                and get_node_shape(reshape_args[0]) is None
+                and len(reshape_args[0].meta["fwd_out"]) > 1
+            ):
                 continue
             chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
             new_shape = ""
@@ -462,16 +467,17 @@ class TraceFlow(object):
         chunk_info["reshape_size"] = reshape_size
         return chunk_info

-    def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
-                               end_idx: int) -> bool:
+    def check_region_start_end(
+        self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
+    ) -> bool:
         """
         check if region start and end is legal
         """
         # dim cannot be None
-        if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+        if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
             return False
         # dim size cannot be 1
-        if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+        if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
             return False
         # must have users
         if len(end_node.users) == 0:
@@ -1,5 +1,5 @@
 import copy
-from typing import Dict, List, Tuple
+from typing import Dict, List

 from torch.fx.node import Node

@@ -412,7 +412,7 @@ class TraceIndice(object):
             node_idx (int)
         """
         # get conv input
-        assert node.kwargs['size'] is None
+        assert node.kwargs["size"] is None
         assert len(get_node_shape(node)) == 4

         # assign index
@@ -826,7 +826,7 @@ class TraceIndice(object):
         # clear compute
         for dim_compute in trace["compute"]:
             for i in range(len(dim_compute) - 1, -1, -1):
-                if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+                if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
                     dim_compute.pop(i)
                     continue
         # clear source
@@ -876,10 +876,24 @@ class TraceIndice(object):
                     self._assign_matmul_indice(node, idx)
                 elif "softmax" == node_name:
                     self._assign_softmax_indice(node, idx)
-                elif any(n == node_name for n in [
-                        "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
-                        "sin", "cos"
-                ]):
+                elif any(
+                    n == node_name
+                    for n in [
+                        "mul",
+                        "add",
+                        "sigmoid",
+                        "relu",
+                        "sub",
+                        "truediv",
+                        "pow",
+                        "dropout",
+                        "where",
+                        "tanh",
+                        "exp",
+                        "sin",
+                        "cos",
+                    ]
+                ):
                     self._assign_elementwise_indice(node, idx)
                 elif "einsum" == node_name:
                     self._assign_einsum_indice(node, idx)
@@ -920,7 +934,7 @@ class TraceIndice(object):
                 else:
                     raise NotImplementedError(node_name, "module not implemented yet!")
             elif node.op == "get_attr":
-                self._assign_all_indice(node, idx)    # get param
+                self._assign_all_indice(node, idx)  # get param
             elif node.op == "output":
                 continue
             else:
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, List, Union

 from torch.fx.node import Node

@@ -10,7 +10,6 @@ logger = get_dist_logger()


 class NodeMgr(object):
-
     def __init__(self, nodes_list: List[Node]) -> None:
         self._node_list = nodes_list
         self._node_dict = {}
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
     # we treat that input node as the input of the checkpoint function
     for node in nodes:
         for input_node in node._input_nodes.keys():
-            if (input_node not in nodes and input_node not in input_nodes
-                    and not is_non_compute_node_except_placeholder(input_node)):
+            if (
+                input_node not in nodes
+                and input_node not in input_nodes
+                and not is_non_compute_node_except_placeholder(input_node)
+            ):
                 input_nodes.append(input_node)

     # if a node has a user node which is not in the node list
     # we treat that user node as the node receiving the current node output
     for node in nodes:
         for output_node in node.users.keys():
-            if (output_node not in nodes and node not in output_nodes
-                    and not is_non_compute_node_except_placeholder_output(output_node)):
+            if (
+                output_node not in nodes
+                and node not in output_nodes
+                and not is_non_compute_node_except_placeholder_output(output_node)
+            ):
                 output_nodes.append(node)

     return input_nodes, output_nodes
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
     for node in node_list:
         if get_node_shape(node) is not None:
             out.append(node)
-        elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
-                node.meta['fwd_out'][0], int):
+        elif (
+            len(node.meta["fwd_out"]) > 0
+            and isinstance(node.meta["fwd_out"], list)
+            and isinstance(node.meta["fwd_out"][0], int)
+        ):
             out.append(node)
     return out