[autochunk] support parsing blocks (#2506)

This commit is contained in:
oahzxl
2023-01-20 11:18:17 +08:00
committed by GitHub
parent 35c0c0006e
commit c04f183237
7 changed files with 314 additions and 22 deletions

View File

@@ -22,7 +22,7 @@ if CODEGEN_AVAILABLE:
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -276,11 +276,17 @@ if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self, meta_graph, max_memory=None, print_mem=False):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False) -> None:
super().__init__()
# find the chunk regions
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
self.chunk_infos = self.search_chunk.search_region()
if print_progress:
get_logger().info("AutoChunk start codegen")
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []

View File

@@ -43,6 +43,8 @@ class EstimateMemory(object):
delete_node = []
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
if to_keep is not None:
keep_list = []
for n in nodes_to_delete:
@@ -135,6 +137,8 @@ class EstimateMemory(object):
if user.op in ("placeholder", "output"):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(user.users) == 0:
nodes_to_delete.append(user)
delete_size = 0
for n in nodes_to_delete:
if n.name in chunk_inputs_names:
@@ -294,3 +298,26 @@ class EstimateMemory(object):
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
def get_active_nodes(self, node_list: List) -> List:
"""
Get active nodes for every node
Args:
node_list (List): _description_
Returns:
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
active_node_list = []
active_node_list_log = []
user_to_last_uses = self._get_last_usr(node_list)
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
for _, node in enumerate(node_list):
# log active node, only effective without chunk
self._add_active_node(node, active_node_list)
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
active_node_list_log.append(copy.deepcopy(active_node_list))
return active_node_list_log

View File

@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -40,14 +40,14 @@ class SearchChunk(object):
print_mem (bool): print estimated memory
"""
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.print_progress = print_progress
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.trace_indice.trace_indice()
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk(
self.trace_indice,
self.estimate_memory,
@@ -55,7 +55,33 @@ class SearchChunk(object):
max_memory=max_memory,
)
def _find_peak_node(self, mem_peak):
def _init_trace(self) -> None:
"""
find the max trace range for every node
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.trace_indice()
def _find_peak_node(self, mem_peak: List) -> int:
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
@@ -73,7 +99,7 @@ class SearchChunk(object):
free_var_idx.append(idx)
return free_var_idx
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node
@@ -81,7 +107,7 @@ class SearchChunk(object):
Args:
active_node (List): active node status for every node
peak_node (Node): peak memory node
peak_node_idx (int): peak memory node idx
chunk_regions (List): chunk region infos
Returns:
@@ -97,7 +123,7 @@ class SearchChunk(object):
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node, -1, -1):
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
@@ -107,21 +133,23 @@ class SearchChunk(object):
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node, len(active_node)):
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_end = i
break
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
@@ -154,6 +182,9 @@ class SearchChunk(object):
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# must have users
if len(end_node.users) == 0:
continue
# check index source align
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
@@ -253,6 +284,9 @@ class SearchChunk(object):
Returns:
chunk_infos (Dict)
"""
if self.print_progress:
get_logger().info("AutoChunk start searching chunk regions")
chunk_infos = []
(
init_mem_peak,
@@ -272,6 +306,11 @@ class SearchChunk(object):
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:

View File

@@ -281,7 +281,10 @@ class TraceFlow(object):
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
input_dict[user_idx] = [None]
else:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None, None
if len(input_dict) == 0:

View File

@@ -33,6 +33,8 @@ class TraceIndice(object):
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
indice_trace_list = []
@@ -48,6 +50,10 @@ class TraceIndice(object):
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
"""
Update the count and return it. To record the idx number.
@@ -493,6 +499,9 @@ class TraceIndice(object):
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
@@ -513,6 +522,9 @@ class TraceIndice(object):
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()
@@ -596,6 +608,37 @@ class TraceIndice(object):
}
self.indice_view_list[node] = view_dict
def _clear_trace(self, node_idx: int) -> None:
"""
clear too far trace to speed up computation
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
@@ -655,3 +698,6 @@ class TraceIndice(object):
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
# limit trace range
self._clear_trace(idx)

View File

@@ -2,6 +2,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
def get_logger():
return logger
def flat_list(inputs: Any) -> List:
"""