[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
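
The last bullet, "[misc] ignore cuda for clang-format", refers to excluding CUDA sources from the clang-format hook. The updated .pre-commit-config.yaml itself is not part of this view; a minimal sketch of what such an exclusion typically looks like, where the repo URL, rev pin, and file pattern are illustrative assumptions rather than the commit's actual values:

    repos:
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v13.0.1  # hypothetical version pin
        hooks:
          - id: clang-format
            exclude: '.*\.cu$'  # skip CUDA sources, per the commit message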

View File

@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
shape_str,
input_node.name,
input_node.name,
)
tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
chunk_output[i].name,
shape_str,
input_node.name,
input_node.name,
)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]]
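
For context, these %-format strings emit plain Python source that runs in the generated chunk-loop prologue. A minimal sketch of the allocation line the non-split branch produces, with hypothetical node names (out_1, x_1) and an illustrative shape:

    # reproduce the allocation string built by the %-format above;
    # the names and shape are illustrative, not from a traced graph
    shape_str = str([4, 64])
    context = "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
        "out_1", shape_str, "x_1", "x_1",
    )
    print(context)  # out_1 = torch.empty([4, 64], dtype=x_1.dtype, device=x_1.device);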
@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
return context
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
node_list: List[Node],
chunk_outputs_idx: int,
chunk_outputs_non_tensor: List[Node],
search_chunk: SearchChunk,
) -> str:
"""
Generate chunk loop end
@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
if (
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
chunk_slice = _gen_chunk_slice_dim(
chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
)
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
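
Likewise, for split/unbind outputs the loop above emits one indexed slice per entry of tensor_meta. A standalone sketch, where the node name sp and the slice string returned by _gen_chunk_slice_dim are hypothetical:

    # hypothetical: a split node `sp` with three outputs
    chunk_slice = "[chunk_idx]"  # stand-in for _gen_chunk_slice_dim's result
    split_chunk_slice = ""
    for i in range(3):
        split_chunk_slice += "%s[%d]%s, " % ("sp", i, chunk_slice)
    split_chunk_slice = split_chunk_slice[:-2]  # drop the trailing ", "
    print(split_chunk_slice)  # sp[0][chunk_idx], sp[1][chunk_idx], sp[2][chunk_idx]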
@@ -216,13 +243,15 @@ def _add_node_slice(
return body
def emit_code_with_chunk(body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False):
def emit_code_with_chunk(
body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False,
):
"""
Emit code with chunk according to chunk_infos.
@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
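
These comprehensions pick apart chunk_infos entries by key. Pieced together from the keys accessed in this file, one entry plausibly looks like the sketch below; only the key names come from the surrounding code, the values are illustrative:

    chunk_info = {
        "region": (12, 47),      # (start, end) node indices of the chunk region
        "inputs": [],            # input nodes sliced per chunk step
        "inputs_non_chunk": [],  # inputs passed through whole
        "inputs_dim": [],        # which dim of each sliced input is chunked
        "chunk_size": 2,         # slices processed per loop iteration
    }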
# chunk outputs
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
))
)
)
if within_chunk_region:
emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
node_list,
chunk_ends[region_idx],
chunk_outputs_non_tensor[region_idx],
search_chunk,
)
)
within_chunk_region = False
node_idx += 1
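
The eval_mem branches instrument the generated code with torch's CUDA peak-memory counters. A standalone sketch of the same measurement pattern, using only real torch.cuda APIs (the helper name and sizes are made up; requires a CUDA device):

    import torch

    def peak_mb(fn, *args):
        # reset the high-water mark, run the op, then read back peak
        # usage in MB relative to what was already allocated
        torch.cuda.reset_peak_memory_stats()
        init_memory = torch.cuda.memory_allocated() / 1024**2
        out = fn(*args)
        peak = torch.cuda.max_memory_allocated() / 1024**2 - init_memory
        return out, peak

    x = torch.randn(1024, 1024, device="cuda")
    _, mb = peak_mb(torch.matmul, x, x)
    print("matmul peak: %.1f MB" % mb)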
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False) -> None:
def __init__(
self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False,
) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})")
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
and node.args[1].isidentifier() and len(node.args) == 2):
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
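
The operator pretty printing earlier in this hunk relies on torch.fx's magic_methods table, which maps operator names to format strings. A sketch of the mechanism with a hand-rolled two-entry stand-in (the real dict lives in torch.fx.graph and its contents vary by version):

    # hypothetical stand-in for torch.fx.graph.magic_methods
    magic_methods = {"add": "{} + {}", "getitem": "{}[{}]"}

    def pretty(op_name, *args):
        # use the template when one exists, else fall back to a call
        if op_name in magic_methods:
            return magic_methods[op_name].format(*args)
        return "%s(%s)" % (op_name, ", ".join(args))

    print(pretty("add", "x_1", "x_2"))    # x_1 + x_2
    print(pretty("getitem", "buf", "0"))  # buf[0]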
@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
self.eval_mem)
emit_code_with_chunk(
body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body

View File

@@ -1,11 +1,8 @@
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, get_node_shape, is_non_memory_node
@@ -62,12 +59,9 @@ class EstimateMemory(object):
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
def _remove_deactive_node(self,
user_idx: int,
user: Node,
active_nodes: List,
delete_node_dict: List,
kept_nodes: List = None) -> None:
def _remove_deactive_node(
self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
) -> None:
"""
remove deactivate nodes from active nodes
"""
@@ -169,7 +163,7 @@ class EstimateMemory(object):
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
@@ -184,7 +178,6 @@ class EstimateMemory(object):
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
@@ -193,8 +186,9 @@ class EstimateMemory(object):
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx])
chunk_ratio = self._get_chunk_ratio(
node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
)
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
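
_get_chunk_ratio's body is not shown in this diff; a hedged guess at the cost-model idea, stated as an assumption rather than the estimator's actual code: a node chunked along dim d keeps only chunk_size / shape[d] of its activation resident per step.

    def chunk_ratio(shape, chunk_dim, chunk_size):
        # assumed cost model: fraction of the activation alive per chunk step
        return min(chunk_size / shape[chunk_dim], 1.0)

    print(chunk_ratio([4, 1024, 64], chunk_dim=1, chunk_size=128))  # 0.125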
@@ -222,7 +216,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None

View File

@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -121,8 +121,10 @@ class SearchChunk(object):
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
if (
i["region"][0] < peak_region[0] <= i["region"][1]
or i["region"][0] < peak_region[1] <= i["region"][1]
):
return None
active_node_num = [len(i) for i in active_node]
@@ -146,9 +148,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@@ -171,7 +173,7 @@ class SearchChunk(object):
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
@@ -180,8 +182,9 @@ class SearchChunk(object):
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
if not self.trace_flow.check_region_start_end(
start_node, start_dim, start_idx, end_node, end_dim, end_idx
):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
@@ -203,7 +206,7 @@ class SearchChunk(object):
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
@@ -215,7 +218,8 @@ class SearchChunk(object):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
self.node_mgr.get_node_by_idx(end_idx)
):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
@@ -279,15 +283,18 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)
self.node_mgr.get_node_list(), chunk_infos
)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
get_logger().info(
"AutoChunk find chunk region %d = (%d, %d)"
% (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
)
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos, print_mem=True
)
return chunk_infos

View File

@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
@@ -20,7 +19,7 @@ class SelectChunk(object):
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
@@ -57,16 +56,18 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -90,13 +91,15 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -109,9 +112,10 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
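
The surrounding method narrows in on the largest chunk size whose estimated peak memory stays under the cap. A hedged sketch of that search, with a toy cost model in place of estimate_chunk_inference_mem and an assumed loop condition (only the mid computation and the right/left updates are visible above):

    def chunk_size_binary_search(left, right, peak_mem_of, cap, gap=4):
        # keep the largest feasible size; `gap` trades precision for
        # fewer calls to the memory estimator
        while right >= left + gap:
            mid = int((left + right) / 2 + 0.5)
            if peak_mem_of(mid) >= cap:
                right = mid - gap
            else:
                left = mid
        return left

    # toy monotone model: peak grows linearly with chunk size
    print(chunk_size_binary_search(1, 64, lambda s: 10 * s, cap=300))  # 26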
@@ -139,8 +143,10 @@ class SelectChunk(object):
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
max_possible_chunk_region = (
min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]),
)
# get mem for chunk region
regions_dict_list = []
@@ -149,15 +155,17 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict_list.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
regions_dict_list.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
@@ -175,7 +183,9 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
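
This legality test is a plain interval-disjointness check over inclusive index ranges; a standalone version with toy regions:

    def regions_disjoint(a, b):
        # two inclusive ranges are disjoint iff one lies entirely before
        # or entirely after the other -- the condition negated above
        return (a[0] > b[1] and a[1] > b[1]) or (a[0] < b[0] and a[1] < b[0])

    print(regions_disjoint((5, 9), (10, 20)))   # True: no shared indices
    print(regions_disjoint((8, 12), (10, 20)))  # False: overlap on 10..12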

View File

@@ -16,7 +16,6 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
@@ -151,7 +150,7 @@ class TraceFlow(object):
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -266,7 +265,7 @@ class TraceFlow(object):
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
) # from last node to first node
) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -328,7 +327,8 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
)
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
@@ -371,8 +371,9 @@ class TraceFlow(object):
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
def _get_other_output_info(
self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
@@ -384,8 +385,8 @@ class TraceFlow(object):
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
@@ -421,7 +422,8 @@ class TraceFlow(object):
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
)
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
@@ -443,8 +445,11 @@ class TraceFlow(object):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
if (
len(reshape_args) == 1
and get_node_shape(reshape_args[0]) is None
and len(reshape_args[0].meta["fwd_out"]) > 1
):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@@ -462,16 +467,17 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
def check_region_start_end(
self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:

View File

@@ -1,5 +1,5 @@
import copy
from typing import Dict, List, Tuple
from typing import Dict, List
from torch.fx.node import Node
@@ -412,7 +412,7 @@ class TraceIndice(object):
node_idx (int)
"""
# get conv input
assert node.kwargs['size'] is None
assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
# assign index
@@ -826,7 +826,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
@@ -876,10 +876,24 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
"sin", "cos"
]):
elif any(
n == node_name
for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
"exp",
"sin",
"cos",
]
):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
@@ -920,7 +934,7 @@ class TraceIndice(object):
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, List, Union
from torch.fx.node import Node
@@ -10,7 +10,6 @@ logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
if (
input_node not in nodes
and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)
):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
if (
output_node not in nodes
and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)
):
output_nodes.append(node)
return input_nodes, output_nodes
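
A toy rendition of this boundary scan, with dicts standing in for fx Nodes and the non-compute filters omitted for brevity:

    def region_io(region):
        # producers outside the region feeding it are inputs; region nodes
        # consumed outside it are outputs
        region_names = {n["name"] for n in region}
        inputs, outputs = [], []
        for n in region:
            for src in n["inputs"]:
                if src not in region_names and src not in inputs:
                    inputs.append(src)
            if n["users"] - region_names and n["name"] not in outputs:
                outputs.append(n["name"])
        return inputs, outputs

    a = {"name": "a", "inputs": ["x"], "users": {"b"}}
    b = {"name": "b", "inputs": ["a"], "users": {"y"}}
    print(region_io([a, b]))  # (['x'], ['b'])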
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
elif (
len(node.meta["fwd_out"]) > 0
and isinstance(node.meta["fwd_out"], list)
and isinstance(node.meta["fwd_out"][0], int)
):
out.append(node)
return out