[autochunk] support transformer (#2526)

This commit is contained in:
oahzxl
2023-01-31 16:00:06 +08:00
committed by GitHub
parent 6e0faa70e0
commit 63199c6687
20 changed files with 1214 additions and 1084 deletions

View File

@@ -8,9 +8,9 @@ from .utils import (
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
flat_list,
get_node_name,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
@@ -79,43 +79,6 @@ class TraceFlow(object):
return node_dim
return None
def check_index_duplicate(self, chunk_infos, return_dim=False):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
duplicate_dims = []
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
duplicate_dim = []
duplicate_flag = False
dim_source = node_trace_source[node_dim]
for k, v in dim_source.items():
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
if k in input_dim_after_node and input_dim_after_node[k] in v:
duplicate_flag = True
duplicate_dim.append((k, v))
duplicate_dims.append(duplicate_dim)
if duplicate_flag:
count += 1
if count > 1:
if return_dim:
return False, duplicate_dims
else:
return False
if return_dim:
return True, None
else:
return True
def _assgin_single_node_flow(
self,
arg_node: Node,
@@ -225,9 +188,12 @@ class TraceFlow(object):
if flow_flag == False:
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
if len(arg_list) >= 2:
# need to mark fix dim
if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
for arg in arg_list:
if get_node_shape(arg) is None:
continue
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
@@ -240,9 +206,8 @@ class TraceFlow(object):
return None
if i not in arg_fix_dim:
arg_fix_dim.append(i)
elif "einsum" in cur_node.name:
pass
elif "matmul" in cur_node.name:
elif any(i == get_node_name(cur_node)
for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
pass
else:
raise NotImplementedError()
@@ -426,7 +391,7 @@ class TraceFlow(object):
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
if any(i == get_node_name(node) for i in ["reshape", "view"]):
reshape_args = flat_list(node.args[1:])
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@@ -443,3 +408,62 @@ class TraceFlow(object):
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
"""
Find chunk info for a region.
We are given the region start and region end, and need to find out all chunk info for it.
We first loop every dim of start node and end node, to see if we can find dim pair,
which is linked in a flow and not computed.
If found, we then search flow in the whole region to find out all chunk infos.
Args:
input_trace (List): node's input trace in region
output_trace (List): node's output trace in region
start_idx (int): region start node index
end_idx (int): region end node index
Returns:
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
continue
# flow search
chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
return False
# must have users
if len(end_node.users) == 0:
return False
# check index source align
if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
return False
# check index copmute
if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
return False
return True