mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy. 1. rewrite search strategy to support multi outputs chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -4,9 +4,10 @@ from torch.fx.node import Node
|
||||
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
NodeMgr,
|
||||
find_chunk_all_input_nodes,
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
find_idx_by_name,
|
||||
find_tensor_shape_node,
|
||||
flat_list,
|
||||
get_node_name,
|
||||
get_node_shape,
|
||||
@@ -16,8 +17,9 @@ from .utils import (
|
||||
|
||||
class TraceFlow(object):
|
||||
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
self.node_mgr = node_mgr
|
||||
|
||||
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||
"""
|
||||
@@ -31,7 +33,8 @@ class TraceFlow(object):
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
# we use start_node_idx instead of real chunk index
|
||||
start_node_idx = self.node_mgr.find_node_idx(start_node)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
|
||||
@@ -39,7 +42,7 @@ class TraceFlow(object):
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
# it means we meet a node outside the loop, and the node is not input node
|
||||
if node_idx < start_idx:
|
||||
if node_idx < start_node_idx:
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -61,29 +64,12 @@ class TraceFlow(object):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||
for k, v in dim_source.items():
|
||||
if k == node_to_idx:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
def _assgin_single_node_flow(
|
||||
self,
|
||||
arg_node: Node,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
cur_node: Node,
|
||||
cur_node_dim: int,
|
||||
cur_node_compute: Dict,
|
||||
cur_node_source: Dict,
|
||||
@@ -109,7 +95,7 @@ class TraceFlow(object):
|
||||
Returns:
|
||||
bool: True if this node can be added to the flow, vice versa.
|
||||
"""
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||
arg_idx = self.node_mgr.find_node_idx(arg_node)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
@@ -126,6 +112,11 @@ class TraceFlow(object):
|
||||
# chunk dim should be None if shape size is 1
|
||||
if get_node_shape(arg_node)[arg_dim] == 1:
|
||||
arg_dim = None
|
||||
# chunk shape should equal cur node
|
||||
elif get_node_shape(arg_node)[arg_dim] != 1:
|
||||
if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
|
||||
if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
|
||||
return False
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
@@ -150,7 +141,7 @@ class TraceFlow(object):
|
||||
return True
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
|
||||
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
@@ -178,6 +169,7 @@ class TraceFlow(object):
|
||||
arg,
|
||||
start_idx,
|
||||
end_idx,
|
||||
cur_node,
|
||||
cur_node_chunk_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
@@ -194,7 +186,7 @@ class TraceFlow(object):
|
||||
for arg in arg_list:
|
||||
if get_node_shape(arg) is None:
|
||||
continue
|
||||
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
|
||||
if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
@@ -232,7 +224,7 @@ class TraceFlow(object):
|
||||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
input_node_idx = self.node_mgr.find_node_idx(input_node)
|
||||
for user in input_node.users.keys():
|
||||
# skip non compute
|
||||
if is_non_compute_node(user):
|
||||
@@ -240,7 +232,7 @@ class TraceFlow(object):
|
||||
# untraced node, mostly non compute
|
||||
if user not in all_node_info:
|
||||
continue
|
||||
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||
user_idx = self.node_mgr.find_node_idx(user)
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
@@ -262,7 +254,7 @@ class TraceFlow(object):
|
||||
inputs.remove(i)
|
||||
return inputs, inputs_dim
|
||||
|
||||
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
|
||||
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
|
||||
"""
|
||||
get all useless nodes in chunk region and prepose them
|
||||
|
||||
@@ -279,8 +271,11 @@ class TraceFlow(object):
|
||||
for node, node_info in all_node_info.items():
|
||||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
|
||||
if node not in all_node_info and node not in chunk_info["outputs"]:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
key=lambda x: self.node_mgr.find_node_idx(x),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
@@ -303,8 +298,7 @@ class TraceFlow(object):
|
||||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
|
||||
end_idx):
|
||||
if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
@@ -328,13 +322,12 @@ class TraceFlow(object):
|
||||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
|
||||
|
||||
return prepose_nodes
|
||||
prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
|
||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
|
||||
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
@@ -345,34 +338,41 @@ class TraceFlow(object):
|
||||
return chunk_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs": [],
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"inputs_dim": [],
|
||||
"outputs": [self.node_mgr.get_node_by_idx(end_idx)],
|
||||
"outputs_non_tensor": {},
|
||||
"outputs_dim": [end_dim],
|
||||
"node_chunk_dim": all_node_info,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# find chunk info for other outputs
|
||||
if len(find_tensor_shape_node(outputs)) > 1:
|
||||
chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
|
||||
if chunk_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
chunk_info["inputs"] = inputs
|
||||
chunk_info["inputs_dim"] = inputs_dim
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
|
||||
self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
|
||||
|
||||
# find non chunk inputs
|
||||
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||
@@ -382,6 +382,63 @@ class TraceFlow(object):
|
||||
|
||||
return chunk_info
|
||||
|
||||
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
|
||||
chunk_info: Dict):
|
||||
start_node = self.node_mgr.get_node_by_idx(start_idx)
|
||||
# loop all outputs
|
||||
for output in outputs:
|
||||
output_legal = False
|
||||
output_idx = self.node_mgr.find_node_idx(output)
|
||||
# skip the origin output
|
||||
if output_idx == end_idx:
|
||||
continue
|
||||
# skip non tensor
|
||||
if get_node_shape(output) is None:
|
||||
# log shape tensor
|
||||
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
|
||||
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
|
||||
continue
|
||||
# loop every dim of outputs, try to find a legal one
|
||||
for output_dim in range(len(get_node_shape(output))):
|
||||
if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
|
||||
continue
|
||||
new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
|
||||
if new_all_node_info is None:
|
||||
continue
|
||||
# check node info legal
|
||||
if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
|
||||
output_legal = True
|
||||
break
|
||||
# not legal
|
||||
if output_legal == False:
|
||||
return None
|
||||
return chunk_info
|
||||
|
||||
def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
|
||||
"""
|
||||
check if there is conflict between new node info and old chunk info. If not, update old chunk info
|
||||
"""
|
||||
# check if conflict
|
||||
overlap_flag = False
|
||||
for k, v in new_all_node_info.items():
|
||||
if k in chunk_info["node_chunk_dim"]:
|
||||
overlap_flag = True
|
||||
if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
|
||||
return False
|
||||
# if no overlap, we just consider them as prepose nodes, instead of new output
|
||||
if overlap_flag == False:
|
||||
return True
|
||||
# update chunk info
|
||||
for k, v in new_all_node_info.items():
|
||||
if k in chunk_info["node_chunk_dim"]:
|
||||
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
|
||||
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
|
||||
else:
|
||||
chunk_info["node_chunk_dim"][k] = v
|
||||
chunk_info["outputs"].append(output)
|
||||
chunk_info["outputs_dim"].append(output_dim)
|
||||
return True
|
||||
|
||||
def _reassgin_reshape_size(self, chunk_info):
|
||||
"""
|
||||
Some shape args in reshape may have changed due to chunk
|
||||
@@ -389,10 +446,17 @@ class TraceFlow(object):
|
||||
"""
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
|
||||
for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
|
||||
if any(i == get_node_name(node) for i in ["reshape", "view"]):
|
||||
if node in chunk_info["args"]["prepose_nodes"]:
|
||||
continue
|
||||
if node.args[0] in chunk_info["inputs_non_chunk"]:
|
||||
continue
|
||||
reshape_args = flat_list(node.args[1:])
|
||||
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
|
||||
reshape_args[0].meta['fwd_out']) > 1:
|
||||
continue
|
||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||
new_shape = ""
|
||||
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||
@@ -409,45 +473,8 @@ class TraceFlow(object):
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
||||
|
||||
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
We are given the region start and region end, and need to find out all chunk info for it.
|
||||
We first loop every dim of start node and end node, to see if we can find dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search flow in the whole region to find out all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
if len(start_traces) > 1: # TODO need to be removed
|
||||
return []
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_indice.node_list[end_idx]
|
||||
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["indice"]):
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
|
||||
end_idx: int) -> bool:
|
||||
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
|
||||
end_idx: int) -> bool:
|
||||
"""
|
||||
check if region start and end is legal
|
||||
"""
|
||||
|
Reference in New Issue
Block a user