mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[autochunk] support transformer (#2526)
This commit is contained in:
@@ -8,9 +8,9 @@ from .utils import (
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
find_idx_by_name,
|
||||
flat_list,
|
||||
get_node_name,
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,43 +79,6 @@ class TraceFlow(object):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
def check_index_duplicate(self, chunk_infos, return_dim=False):
    """
    Check that no node inside the chunk region has more than one dim whose
    trace source matches a dim inherited from a chunk input.

    Args:
        chunk_infos (dict): chunk description. This method reads
            ``chunk_infos["inputs"]`` (input nodes),
            ``chunk_infos["inputs_dim"]`` (one mapping per input, iterated as
            ``node index -> dims``) and ``chunk_infos["region"]``
            (``[start, end]`` node indices, both inclusive).
        return_dim (bool): if True, return a tuple instead of a bare bool.

    Returns:
        ``True`` when no node has more than one matching dim, otherwise
        ``False``. With ``return_dim=True`` the result is ``(True, None)`` or
        ``(False, duplicate_dims)``, where ``duplicate_dims`` holds, for every
        dim of the offending node, the list of matching ``(node index, dims)``
        source entries.
    """
    # node index -> dim of that node inherited from a chunk input dim
    input_dim_after_node = {}
    for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
        for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
            inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
            # Only None means "no inherited dim". A plain truthiness test
            # would also throw away a result of 0, which is falsy.
            if inherit_dim is not None:
                input_dim_after_node[k] = inherit_dim

    region_start = chunk_infos["region"][0]
    region_end = chunk_infos["region"][1]
    for node in self.trace_indice.node_list[region_start:region_end + 1]:
        if is_non_compute_node_except_placeholder(node):
            continue
        count = 0
        duplicate_dims = []
        node_trace_source = self.trace_indice._find_source_trace_from_node(node)
        for node_dim in range(len(get_node_shape(node))):
            duplicate_dim = []
            duplicate_flag = False
            for k, v in node_trace_source[node_dim].items():
                # only sources located inside the region are relevant
                if not (region_start <= k <= region_end):
                    continue
                if k in input_dim_after_node and input_dim_after_node[k] in v:
                    duplicate_flag = True
                    duplicate_dim.append((k, v))
            duplicate_dims.append(duplicate_dim)
            if duplicate_flag:
                count += 1

        # more than one dim of this node matches an inherited input dim
        if count > 1:
            return (False, duplicate_dims) if return_dim else False

    return (True, None) if return_dim else True
|
||||
|
||||
def _assgin_single_node_flow(
|
||||
self,
|
||||
arg_node: Node,
|
||||
@@ -225,9 +188,12 @@ class TraceFlow(object):
|
||||
if flow_flag == False:
|
||||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
|
||||
if len(arg_list) >= 2:
|
||||
# need to mark fix dim
|
||||
if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
|
||||
for arg in arg_list:
|
||||
if get_node_shape(arg) is None:
|
||||
continue
|
||||
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
@@ -240,9 +206,8 @@ class TraceFlow(object):
|
||||
return None
|
||||
if i not in arg_fix_dim:
|
||||
arg_fix_dim.append(i)
|
||||
elif "einsum" in cur_node.name:
|
||||
pass
|
||||
elif "matmul" in cur_node.name:
|
||||
elif any(i == get_node_name(cur_node)
|
||||
for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
@@ -426,7 +391,7 @@ class TraceFlow(object):
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
if any(i == get_node_name(node) for i in ["reshape", "view"]):
|
||||
reshape_args = flat_list(node.args[1:])
|
||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||
new_shape = ""
|
||||
@@ -443,3 +408,62 @@ class TraceFlow(object):
|
||||
reshape_size[node.name] = [origin_shape, new_shape]
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
||||
|
||||
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
    """
    Collect every chunk info that can be built for one region.

    Given the region's start and end node indices, every dim of the end node
    is paired with every dim of each start node. A pair is kept only if
    ``_check_region_start_end`` accepts it; for those pairs ``flow_search``
    is run over the region and every non-``None`` result is collected.

    Args:
        input_trace (List): node's input trace in region
        output_trace (List): node's output trace in region
        start_idx (int): region start node index
        end_idx (int): region end node index

    Returns:
        chunk_infos: possible regions found (empty if the start node has
        more than one input trace)
    """
    region_inputs = input_trace[start_idx]
    # TODO need to be removed
    if len(region_inputs) > 1:
        return []

    region_output = output_trace[end_idx]
    region_end_node = self.trace_indice.node_list[end_idx]

    found = []
    for end_dim, _ in enumerate(region_output["indice"]):
        for start_node, start_trace in region_inputs.items():
            for start_dim, _ in enumerate(start_trace["indice"]):
                legal = self._check_region_start_end(start_node, start_dim, start_idx, region_end_node, end_dim,
                                                     end_idx)
                if legal:
                    # flow search
                    candidate = self.flow_search(start_idx, start_dim, end_idx, end_dim)
                    if candidate is not None:
                        found.append(candidate)
    return found
|
||||
|
||||
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
                            end_idx: int) -> bool:
    """
    Decide whether (start_node, start_dim) and (end_node, end_dim) form a
    legal region start/end pair.

    The pair is rejected when either node has no shape, when either selected
    dim has size 1, when the end node has no users, or when
    ``check_index_source`` / ``check_index_compute`` reject it.

    Returns:
        bool: True if every check passes, False otherwise.
    """
    # shapes must be known for both nodes (end node is looked up first)
    end_shape = get_node_shape(end_node)
    if end_shape is None:
        return False
    start_shape = get_node_shape(start_node)
    if start_shape is None:
        return False

    # dim size cannot be 1
    if end_shape[end_dim] == 1 or start_shape[start_dim] == 1:
        return False

    # must have users
    if len(end_node.users) == 0:
        return False

    # check index source align
    source_aligned = self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node)
    if not source_aligned:
        return False

    # check index compute
    compute_ok = self.check_index_compute(start_idx, end_dim, end_node, end_idx)
    if not compute_ok:
        return False

    return True
|
||||
|
Reference in New Issue
Block a user