Mirror of https://github.com/hpcaitech/ColossalAI.git

[autochunk] support autochunk on evoformer (#2497)
@@ -1,8 +1,13 @@
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
 from .trace_indice import TraceIndice
 from .utils import (
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
+    flat_list,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
@@ -171,7 +176,7 @@ class TraceFlow(object):
                 # get cur node info
                 cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                 cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-                if cur_node_chunk_dim:
+                if cur_node_chunk_dim is not None:
                     cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                     cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
                 else:
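The "is not None" change above matters because a chunk dim can legitimately be 0, which is falsy in Python, so the old truthiness test routed dimension 0 into the "no chunk dim" branch. A minimal standalone sketch of the difference (function names are illustrative, not from the repo):

    def pick_branch_old(chunk_dim):
        # old check: dim 0 is falsy and falls into the "no chunk dim" branch
        return "trace" if chunk_dim else "skip"

    def pick_branch_new(chunk_dim):
        # fixed check: only a missing dim (None) is treated as "no chunk dim"
        return "trace" if chunk_dim is not None else "skip"

    assert pick_branch_old(0) == "skip"    # wrong: dim 0 silently ignored
    assert pick_branch_new(0) == "trace"   # correct
    assert pick_branch_new(None) == "skip"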
@@ -223,15 +228,32 @@ class TraceFlow(object):
             cur_node_list = next_node_list
         return all_node_info
 
-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+        """
+        Get chunk dim for every input node for their every entry, remove unchunked nodes
+
+        Args:
+            inputs (List[Node]): input nodes
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            inputs (List(Node)): new inputs
+            inputs_dim (List): chunk dim for inputs
+        """
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
             input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
+                # skip non compute
+                if is_non_compute_node(user):
+                    continue
+                # untraced node, mostly non compute
                 if user not in all_node_info:
                     continue
                 user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
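For context: all_node_info maps each traced fx node to its chunk metadata ("chunk_dim", "fix_dim"), and inputs_dim records, for every surviving input, which in-region user consumes it and along which dim it is chunked; inputs that no in-region user chunks are dropped. A dependency-free sketch of the same filtering idea, using plain dicts and names instead of torch.fx nodes and simplified to record the user's chunk dim rather than its full source trace (all names below are illustrative):

    def get_input_dims(inputs, users_of, all_node_info, start_idx, end_idx, idx_of):
        # inputs: input names; users_of: name -> list of user names
        # all_node_info: name -> {"chunk_dim": int | None}; idx_of: name -> topological index
        kept, inputs_dim = [], []
        for inp in inputs:
            input_dict = {}
            for user in users_of.get(inp, []):
                if user not in all_node_info:            # untraced user, skip
                    continue
                user_idx = idx_of[user]
                if start_idx <= user_idx <= end_idx:
                    chunk_dim = all_node_info[user]["chunk_dim"]
                    if chunk_dim is not None:
                        input_dict[user_idx] = chunk_dim
            if input_dict:                               # at least one user chunks this input
                kept.append(inp)
                inputs_dim.append(input_dict)
        return kept, inputs_dim

    users_of = {"x": ["matmul"], "bias": ["add"]}
    all_node_info = {"matmul": {"chunk_dim": 1}, "add": {"chunk_dim": None}}
    idx_of = {"matmul": 5, "add": 6}
    print(get_input_dims(["x", "bias"], users_of, all_node_info, 4, 8, idx_of))
    # -> (['x'], [{5: 1}]): "bias" is dropped because no user chunks it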
@@ -245,12 +267,24 @@ class TraceFlow(object):
                 remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
         # remove unchunked inputs
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
         return inputs, inputs_dim
 
-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+        """
+        get all useless nodes in chunk region and prepose them
+
+        Args:
+            all_node_info (Dict): describe all node's chunk dim and fix dim
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            List[Node]: all nodes to be preposed
+        """
+        # get all possible prepose nodes
         maybe_prepose_nodes = []
         for node, node_info in all_node_info.items():
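"Prepose" here means hoisting nodes that sit inside the chunk region but take no part in the chunked data flow, so they run once before the generated chunk loop instead of once per chunk. A rough, dependency-only sketch of that criterion (illustrative only; the real routine also checks chunk/fix dims via all_node_info and walks the args transitively):

    def find_preposable(region_nodes, deps, chunked):
        # region_nodes: node names inside the chunk region, in topological order
        # deps: name -> set of argument names; chunked: names on the chunked flow
        preposable = []
        for name in region_nodes:
            if name in chunked:
                continue
            # hoistable if no argument comes from the chunked flow
            if all(d not in chunked for d in deps.get(name, set())):
                preposable.append(name)
        return preposable

    deps = {"scale": {"w"}, "mul": {"scale", "x_chunk"}}
    print(find_preposable(["scale", "mul"], deps, chunked={"x_chunk", "mul"}))
    # -> ['scale']: computed from weights only, safe to move before the chunk loop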
@@ -276,7 +310,7 @@ class TraceFlow(object):
                 for cur_prepose_node in tmp_cur_prepose_nodes:
                     if prepose_flag == False:
                         break
-                    for cur_prepose_node_arg in cur_prepose_node.args:
+                    for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                         if type(cur_prepose_node_arg) != type(cur_prepose_node):
                             continue
                         # out of loop
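Switching the scan from cur_prepose_node.args to cur_prepose_node.all_input_nodes matters because in torch.fx Node.args can mix real Node inputs with constants and nested containers, and it does not cover kwargs, while all_input_nodes returns exactly the Node inputs. A small self-contained illustration (toy module, not from this PR):

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):
            # the literal 1 appears in .args as a plain int, not a Node
            return torch.add(x, 1, alpha=2)

    gm = symbolic_trace(M())
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(add_node.args)             # (x, 1)  -> mixes a Node with a constant
    print(add_node.all_input_nodes)  # [x]     -> only genuine Node inputs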
@@ -360,19 +394,28 @@ class TraceFlow(object):
         return chunk_info
 
     def _reassgin_reshape_size(self, chunk_info):
         """
         Some shape args in reshape may have changed due to chunk
         reassgin those changed shape
         """
         chunk_region = chunk_info["region"]
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
+                reshape_log = self.trace_indice.indice_view_list[node]
+                reshape_args = flat_list(node.args[1:])
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
+                new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+                    if reshape_arg_dim in reshape_log["dim_to"]:
+                        continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
+                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
                     else:
+                        if isinstance(reshape_arg, int):
+                            new_shape += "%s, " % str(reshape_arg)
+                        else:
+                            new_shape += "%s, " % reshape_arg.name
+                new_shape = new_shape[:-2]
+                origin_shape = str(reshape_args)[1:-1]
+                reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
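What _reassgin_reshape_size produces is, per reshape/view node, the original shape-argument string and a rewritten one in which the chunked entry becomes min(chunk_size, %d - chunk_idx), so the code later emitted for the chunk loop reshapes each slice to its true size, including a possibly truncated last chunk. A toy rendering of that substitution with made-up values:

    def render_new_shape(reshape_args, chunk_dim, chunk_shape):
        # reshape_args: ints or symbolic dim names; chunk_dim: index of the chunked entry
        parts = []
        for dim, arg in enumerate(reshape_args):
            if dim == chunk_dim:
                parts.append("min(chunk_size, %d - chunk_idx)" % chunk_shape)
            else:
                parts.append(str(arg))
        return ", ".join(parts)

    print(render_new_shape(["batch", 128, 64], chunk_dim=1, chunk_shape=128))
    # -> batch, min(chunk_size, 128 - chunk_idx), 64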