[autochunk] support transformer (#2526)

This commit is contained in:
oahzxl
2023-01-31 16:00:06 +08:00
committed by GitHub
parent 6e0faa70e0
commit 63199c6687
20 changed files with 1214 additions and 1084 deletions

View File

@@ -3,7 +3,14 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
from .utils import (
find_first_tensor_arg,
find_idx_by_name,
flat_list,
get_module_node_name,
get_node_name,
get_node_shape,
)
class TraceIndice(object):
@@ -36,7 +43,7 @@ class TraceIndice(object):
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
def _init_indice_trace_list(self) -> List:
indice_trace_list = []
for n in self.node_list:
if get_node_shape(n) != None:
@@ -54,7 +61,7 @@ class TraceIndice(object):
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
def _add_indice(self) -> int:
"""
Update the count and return it. To record the idx number.
@@ -64,39 +71,30 @@ class TraceIndice(object):
self.indice_count += 1
return self.indice_count
def _del_dim(self, idx, dim_idx):
def _del_dim(self, idx: int, dim_idx: int) -> None:
"""
delete a dim for indice, compute and source
"""
self.indice_trace_list[idx]["indice"].pop(dim_idx)
self.indice_trace_list[idx]["compute"].pop(dim_idx)
self.indice_trace_list[idx]["source"].pop(dim_idx)
def _add_dim(self, node_idx, dim_idx):
def _add_dim(self, node_idx: int, dim_idx: int) -> None:
"""
add a dim for indice, compute and source
"""
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
def _transform_indice(self, node, node_dim):
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
def _inherit_all_computation(self, node_from, node_to):
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
assert len(node_from_compute) == len(node_to_compute)
for i in range(len(node_from_compute)):
self._add_source(node_from, i, node_to, i)
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
def _add_source(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init=False,
) -> None:
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_indice(node_to, node_to_dim)
@@ -119,7 +117,50 @@ class TraceIndice(object):
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
def _transform_indice(self, node: Node, node_dim: int) -> int:
node_idx = self._find_indice_trace_from_node(node)
dims = list(range(len(node_idx)))
return dims[node_dim]
def _inherit_indice(
self,
node_from: Node,
node_from_dim: int,
node_to: Node,
node_to_dim: int,
init: bool = True,
) -> None:
"""
node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source
"""
node_from_dim = self._transform_indice(node_from, node_from_dim)
node_to_dim = self._transform_indice(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
if init:
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
else:
for j in node_from_trace["compute"][node_from_dim]:
if j not in node_to_trace["compute"][node_to_dim]:
node_to_trace["compute"][node_to_dim].append(j)
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init)
def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None:
"""
inherit all dims with init
"""
# find indice just for assert length
node_from_indice = self._find_indice_trace_from_node(node_from)
node_to_indice = self._find_indice_trace_from_node(node_to)
assert len(node_from_indice) == len(node_to_indice)
for i in range(len(node_from_indice)):
self._inherit_indice(node_from, i, node_to, i, init=True)
def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
"""
inheirt indice from node without init
"""
if exclude == None:
exclude = []
else:
@@ -130,12 +171,9 @@ class TraceIndice(object):
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_indice(node_to, i) in exclude:
continue
self._add_source(node_from, i, node_to, i)
for j in node_from_compute[i]:
if j not in node_to_compute[i]:
node_to_compute[i].append(j)
self._inherit_indice(node_from, i, node_to, i, init=False)
def _mark_computation(self, node, idx, dim):
def _mark_computation(self, node: Node, idx: int, dim: int) -> None:
"""
Mark some dims of node as computed.
@@ -152,7 +190,7 @@ class TraceIndice(object):
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
def _find_trace_from_node(self, node):
def _find_trace_from_node(self, node: Node) -> Dict:
"""
Find node idx and compute trace by the node.
@@ -166,7 +204,7 @@ class TraceIndice(object):
node_dict = self.indice_trace_list[node_idx]
return node_dict
def _find_source_trace_from_node(self, node):
def _find_source_trace_from_node(self, node: Node) -> List:
"""
Find node source trace by the node.
@@ -180,7 +218,7 @@ class TraceIndice(object):
node_dict = self.indice_trace_list[node_idx]
return node_dict["source"]
def _find_indice_trace_from_node(self, node):
def _find_indice_trace_from_node(self, node) -> List:
"""
Find node idx trace by the node.
@@ -192,7 +230,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["indice"]
def _find_compute_trace_from_node(self, node):
def _find_compute_trace_from_node(self, node: Node) -> List:
"""
Find node compute trace by the node.
@@ -204,7 +242,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
"""
Assign node's trace as its input node.
@@ -214,15 +252,9 @@ class TraceIndice(object):
"""
if input_node == None:
input_node = find_first_tensor_arg(node)
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
self._inherit_all_indice(input_node, node)
new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node: Node, node_idx: int):
def _assign_all_indice(self, node: Node, node_idx: int) -> None:
"""
Add new indice for all node's dims.
@@ -238,7 +270,7 @@ class TraceIndice(object):
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node: Node, node_idx: int):
def _assign_transpose_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
@@ -255,7 +287,7 @@ class TraceIndice(object):
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node: Node, node_idx: int):
def _assign_permute_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for permute op.
1. swap input's dim according to permute args
@@ -272,7 +304,7 @@ class TraceIndice(object):
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node: Node, node_idx: int):
def _assign_linear_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for linear op.
1. copy trace from input node and change last indice accroding to weight
@@ -293,7 +325,23 @@ class TraceIndice(object):
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node: Node, node_idx: int):
def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for addmm op.
Args:
node (node)
node_idx (int)
"""
bias, input_node, weight = node.args
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(weight, 1, node, -1)
self._inherit_indice(bias, -1, node, -1)
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
@@ -310,7 +358,7 @@ class TraceIndice(object):
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
def _assign_layernorm_indice(self, node, idx):
@@ -341,14 +389,13 @@ class TraceIndice(object):
for node_in in node.args:
if type(node_in) == type(node):
nodes_in.append(node_in)
self._mark_computation_from_node(node_in, node)
assert len(nodes_in) <= 2
self._inherit_more_indice_from_node(node_in, node)
def _assgin_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node)
self._inherit_more_indice_from_node(node_in, node)
def _assign_einsum_indice(self, node, idx):
"""
@@ -365,7 +412,7 @@ class TraceIndice(object):
left, right = patterns.split("->")
left = left.split(",")
if '...' in right:
if "..." in right:
replace_list = "!@#$%^&*"
target_len = len(get_node_shape(node))
add_len = target_len - len(right) + 3
@@ -399,7 +446,22 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
def _assign_split_indice(self, node: Node, node_idx: int) -> None:
    """
    Assign indice for split op.
    1. add one dim to the node's trace per dim of the input's shape
    2. copy trace from input node
    3. replace the split dim with a fresh indice

    Args:
        node (node): split node, the split dim is read from ``node.kwargs["dim"]``
        node_idx (int)
    """
    for _ in range(len(get_node_shape(node.args[0]))):
        self._add_dim(node_idx, 0)
    self._assign_indice_as_input(node, node_idx)

    # Normalise a negative split dim before deleting/re-adding it: inserting
    # at a negative position goes through list.insert, which places the new
    # entry *before* that position and would shift the dim by one.
    rank = len(self.indice_trace_list[node_idx]["indice"])
    dim_idx = range(rank)[node.kwargs["dim"]]
    self._del_dim(node_idx, dim_idx)
    self._add_dim(node_idx, dim_idx)
def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
@@ -416,18 +478,7 @@ class TraceIndice(object):
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node: Node, node_idx: int):
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node: Node, node_idx: int):
def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for oneslike op.
1. assign new indice for all dim
@@ -438,7 +489,7 @@ class TraceIndice(object):
"""
self._assign_all_indice(node, node_idx)
def _assign_cat_indice(self, node: Node, node_idx: int):
def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for cat op.
@@ -449,12 +500,12 @@ class TraceIndice(object):
nodes_in = flat_list(node.args[0])
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
self._inherit_more_indice_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
self._add_dim(node_idx, cat_dim)
def _assign_sum_indice(self, node: Node, node_idx: int):
def _assign_sum_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for sum op.
@@ -466,11 +517,46 @@ class TraceIndice(object):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._mark_computation_from_node(n, node)
self._inherit_more_indice_from_node(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
def _assign_getitem_indice(self, node: Node, node_idx: int):
def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for arange op.
Args:
node (node)
node_idx (int)
"""
self._assign_all_indice(node, node_idx)
def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
    """
    Assign indice for tensor op.
    Only nodes whose shape has length 0 are supported; they need no indice,
    so nothing is assigned.

    Args:
        node (node)
        node_idx (int)

    Raises:
        NotImplementedError: if the node's shape has at least one dim
    """
    if len(get_node_shape(node)) != 0:
        raise NotImplementedError()
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for embedding op.
Args:
node (node)
node_idx (int)
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, -1)
def _assign_getitem_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for getitem.
getitem can act like slice sometimes
@@ -480,6 +566,19 @@ class TraceIndice(object):
node_idx (int)
"""
node_args = flat_list(node.args[1:])
# deal with split
if get_node_name(node.args[0]) == "split":
self._assign_indice_as_input(node, node_idx)
self._del_dim(node_idx, node.args[0].kwargs["dim"])
self._add_dim(node_idx, node.args[0].kwargs["dim"])
return
# skip non tensor
if get_node_shape(node) is None:
return
# find if slice
flag = False
for node_arg in node_args:
node_arg_str = str(node_arg)
@@ -528,7 +627,7 @@ class TraceIndice(object):
else:
raise NotImplementedError()
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
@@ -536,7 +635,7 @@ class TraceIndice(object):
3. determine changed dim, and assgin indice for generated dim.
4. log changed dim and generated dim for restore
5. inherit computation.
6. TODO: look into view list to see whether the view is associated with other,
6. look into view list to see whether the view is associated with other,
if so assgin equal dim according to previous view.
Args:
@@ -552,7 +651,7 @@ class TraceIndice(object):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
else:
target_shape.append(unflated_args[i].meta["fwd_out"][0])
target_shape.extend(unflated_args[i].meta["fwd_out"])
# compute the value of -1
if -1 in target_shape:
@@ -579,17 +678,36 @@ class TraceIndice(object):
dim_from = [dim_equal.index(False)]
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
elif len_diff == 0:
# dim equal
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = []
dim_to = []
else:
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
idx_from = [origin_trace[i] for i in dim_from]
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
for i in dim_to:
self._add_dim(node_idx, i)
dim_from.reverse()
# search view list
for view_node, view_dict in self.indice_view_list.items():
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inheirt indice from current node
for dim_to_i in dim_to:
for dim_from_i in dim_from:
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
# inherid indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inherit computation
compute_log = self._find_compute_trace_from_node(origin_node)
@@ -630,7 +748,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
@@ -639,59 +757,82 @@ class TraceIndice(object):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_list):
node_name = get_node_name(node)
if node.op == "placeholder":
self._assign_all_indice(node, idx)
elif node.op == "call_method":
if "transpose" in node.name:
if "transpose" == node_name:
self._assign_transpose_indice(node, idx)
elif "permute" in node.name:
elif "permute" == node_name:
self._assign_permute_indice(node, idx)
elif "view" in node.name or "reshape" in node.name:
elif "view" == node_name or "reshape" == node_name:
self._assign_view_reshape_indice(node, idx)
elif "unsqueeze" in node.name:
elif "unsqueeze" == node_name:
self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous", "clone"]):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
elif "new_ones" == node_name:
self._assign_ones_like_indice(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" in node.name:
self._assign_linear_indice(node, idx)
elif "cat" in node.name:
self._assign_cat_indice(node, idx)
elif "matmul" in node.name:
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
self._assign_softmax_indice(node, idx)
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name:
self._assign_ones_like_indice(node, idx)
elif "dropout" in node.name:
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "sum" in node.name:
self._assign_sum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
elif any(i == node_name for i in ["size"]):
continue
else:
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]):
raise NotImplementedError(node_name, "method not implemented yet!")
elif node.op == "call_function":
if "linear" == node_name:
self._assign_linear_indice(node, idx)
elif "cat" == node_name:
self._assign_cat_indice(node, idx)
elif "matmul" == node_name:
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" == node_name:
self._assign_ones_like_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
elif "sum" == node_name:
self._assign_sum_indice(node, idx)
elif "layer_norm" == node_name:
self._assign_layernorm_indice(node, idx)
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
elif "getitem" == node_name:
self._assign_getitem_indice(node, idx)
elif "addmm" == node_name:
self._assign_addmm_indice(node, idx)
elif "arange" == node_name:
self._assign_arange_indice(node, idx)
elif "tensor" == node_name:
self._assign_arange_indice(node, idx)
elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
continue
else:
raise NotImplementedError(node_name, "function not implemented yet!")
elif node.op == "call_module":
node_name = get_module_node_name(node)
if "layernorm" == node_name:
self._assign_layernorm_indice(node, idx)
elif "embedding" == node_name:
self._assign_embedding_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
elif node.op == "output":