[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously, only single-output chunk search was supported. The new search is more flexible and improves performance by a large margin. For transformers, it reduces memory by 40% compared with the previous search strategy.

1. rewrite search strategy to support multi outputs chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -4,9 +4,10 @@ from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
find_tensor_shape_node,
flat_list,
get_node_name,
get_node_shape,
@@ -16,8 +17,9 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice) -> None:
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
"""
@@ -31,7 +33,8 @@ class TraceFlow(object):
Returns:
bool: True if check pass
"""
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
# we use start_node_idx instead of real chunk index
start_node_idx = self.node_mgr.find_node_idx(start_node)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
@@ -39,7 +42,7 @@ class TraceFlow(object):
if node_idx == start_node_idx and start_dim in node_dim:
return True
# it means we meet a node outside the loop, and the node is not input node
if node_idx < start_idx:
if node_idx < start_node_idx:
return False
return False
@@ -61,29 +64,12 @@ class TraceFlow(object):
return False
return True
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
    """
    Look up the dim entry that ``node_to`` contributes to one dim of ``node_from``.

    Args:
        node_from: node whose source trace is queried
        node_from_dim: key selecting one entry of that source trace
        node_to: node whose index (resolved by name in the trace node list)
            is searched for among the keys of the selected entry

    Returns:
        The value stored under ``node_to``'s index, or None when that index
        is not a key of the selected source entry.
    """
    trace_source = self.trace_indice._find_source_trace_from_node(node_from)
    candidates = trace_source[node_from_dim]
    target_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
    # first (and only possible) key equal to the target index, else None
    return next((dims for idx, dims in candidates.items() if idx == target_idx), None)
def _find_inherit_dim(self, input_node, input_dim, node):
    """
    Find the first dim of ``node`` whose source trace refers to ``input_dim[0]``
    of ``input_node``.

    Args:
        input_node: node whose index is searched for in ``node``'s source trace
        input_dim: indexable; only its first element is used
        node: node whose dims are scanned in increasing order

    Returns:
        The matching dim index of ``node``, or None if no dim matches.
    """
    src_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
    per_dim_source = self.trace_indice._find_source_trace_from_node(node)
    num_dims = len(get_node_shape(node))
    for dim in range(num_dims):
        dim_source = per_dim_source[dim]
        # this dim has no record of the input node at all
        if src_idx not in dim_source:
            continue
        if input_dim[0] in dim_source[src_idx]:
            return dim
    return None
def _assgin_single_node_flow(
self,
arg_node: Node,
start_idx: int,
end_idx: int,
cur_node: Node,
cur_node_dim: int,
cur_node_compute: Dict,
cur_node_source: Dict,
@@ -109,7 +95,7 @@ class TraceFlow(object):
Returns:
bool: True if this node can be added to the flow, vice versa.
"""
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
arg_idx = self.node_mgr.find_node_idx(arg_node)
# arg in chunk range or be inputs
if not (start_idx <= arg_idx < end_idx):
return True
@@ -126,6 +112,11 @@ class TraceFlow(object):
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
# chunk shape should equal cur node
elif get_node_shape(arg_node)[arg_dim] != 1:
if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
return False
else:
arg_dim = None
@@ -150,7 +141,7 @@ class TraceFlow(object):
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -178,6 +169,7 @@ class TraceFlow(object):
arg,
start_idx,
end_idx,
cur_node,
cur_node_chunk_dim,
cur_node_compute,
cur_node_source,
@@ -194,7 +186,7 @@ class TraceFlow(object):
for arg in arg_list:
if get_node_shape(arg) is None:
continue
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
@@ -232,7 +224,7 @@ class TraceFlow(object):
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
input_node_idx = self.node_mgr.find_node_idx(input_node)
for user in input_node.users.keys():
# skip non compute
if is_non_compute_node(user):
@@ -240,7 +232,7 @@ class TraceFlow(object):
# untraced node, mostly non compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
user_idx = self.node_mgr.find_node_idx(user)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
@@ -262,7 +254,7 @@ class TraceFlow(object):
inputs.remove(i)
return inputs, inputs_dim
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
"""
get all useless nodes in chunk region and prepose them
@@ -279,8 +271,11 @@ class TraceFlow(object):
for node, node_info in all_node_info.items():
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
if node not in all_node_info and node not in chunk_info["outputs"]:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
) # from last node to first node
prepose_nodes = []
@@ -303,8 +298,7 @@ class TraceFlow(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
end_idx):
if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
@@ -328,13 +322,12 @@ class TraceFlow(object):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
return prepose_nodes
prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
@@ -345,34 +338,41 @@ class TraceFlow(object):
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
# only single ouput
if len(outputs) > 1:
return None
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info = {
"region": (start_idx, end_idx),
"inputs": inputs,
"inputs": [],
"inputs_non_chunk": [],
"inputs_dim": inputs_dim,
"outputs": outputs,
"outputs_dim": end_dim,
"inputs_dim": [],
"outputs": [self.node_mgr.get_node_by_idx(end_idx)],
"outputs_non_tensor": {},
"outputs_dim": [end_dim],
"node_chunk_dim": all_node_info,
"args": {},
}
# find chunk info for other outputs
if len(find_tensor_shape_node(outputs)) > 1:
chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
if chunk_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info["inputs"] = inputs
chunk_info["inputs_dim"] = inputs_dim
# move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -382,6 +382,63 @@ class TraceFlow(object):
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
                           chunk_info: Dict):
    """
    Try to fit every output other than the region's end node into ``chunk_info``.

    For each extra output with a shape, dims are tried in increasing order
    until one passes ``check_region_start_end``, yields node info from
    ``_get_all_node_info`` and is accepted by ``_update_chunk_info``.
    Outputs without a shape are skipped; when their ``meta['fwd_out']`` is
    non-empty and starts with an int, its string form is recorded in
    ``chunk_info["outputs_non_tensor"]``.

    Args:
        outputs (List[Node]): output nodes of the region
        start_idx (int): region start node index
        start_dim (int): dim of the start node being searched
        end_idx (int): region end node index; the output at this index is skipped
        end_dim (int): dim of the end node (not read by this method)
        chunk_info (Dict): chunk info to update in place

    Returns:
        The updated ``chunk_info``, or None if some output has no legal dim.
    """
    region_head = self.node_mgr.get_node_by_idx(start_idx)
    for output in outputs:
        output_idx = self.node_mgr.find_node_idx(output)
        # the original output is already described by chunk_info
        if output_idx == end_idx:
            continue
        # non tensor output: only log it when it carries int values
        if get_node_shape(output) is None:
            fwd_out = output.meta['fwd_out']
            if len(fwd_out) > 0 and isinstance(fwd_out[0], int):
                chunk_info["outputs_non_tensor"][output] = str(fwd_out)
            continue
        # scan the dims of this output for the first legal one
        for output_dim in range(len(get_node_shape(output))):
            if not self.check_region_start_end(region_head, start_dim, start_idx, output, output_dim, output_idx):
                continue
            candidate_info = self._get_all_node_info(output_dim, start_idx, output_idx)
            if candidate_info is None:
                continue
            if self._update_chunk_info(chunk_info, candidate_info, output, output_dim):
                break
        else:
            # no dim of this output is legal, so the whole search fails
            return None
    return chunk_info
def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
"""
check if there is conflict between new node info and old chunk info. If not, update old chunk info
"""
# check if conflict
overlap_flag = False
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
overlap_flag = True
if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
return False
# if no overlap, we just consider them as prepose nodes, instead of new output
if overlap_flag == False:
return True
# update chunk info
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunk
@@ -389,10 +446,17 @@ class TraceFlow(object):
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
if any(i == get_node_name(node) for i in ["reshape", "view"]):
if node in chunk_info["args"]["prepose_nodes"]:
continue
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
@@ -409,45 +473,8 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
    """
    Collect every chunk info that can be found for one region.

    Given the start and end of a region, every (start dim, end dim) pair is
    tried. Pairs rejected by ``_check_region_start_end`` are dropped; for the
    others a flow search over the whole region is run, and each non-None
    result is kept.

    Args:
        input_trace (List): node's input trace in region
        output_trace (List): node's output trace in region
        start_idx (int): region start node index
        end_idx (int): region end node index

    Returns:
        chunk_infos: possible regions found; empty when the start node has
        more than one input trace
    """
    start_traces = input_trace[start_idx]
    # TODO need to be removed: regions with several start traces are not searched
    if len(start_traces) > 1:
        return []

    end_trace = output_trace[end_idx]
    end_node = self.trace_indice.node_list[end_idx]

    found = []
    for end_dim, _ in enumerate(end_trace["indice"]):
        for start_node, start_trace in start_traces.items():
            for start_dim, _ in enumerate(start_trace["indice"]):
                pair_ok = self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
                                                       end_idx)
                if not pair_ok:
                    continue
                # flow search
                candidate = self.flow_search(start_idx, start_dim, end_idx, end_dim)
                if candidate is not None:
                    found.append(candidate)
    return found
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if region start and end is legal
"""