mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[autochunk] support autochunk on evoformer (#2497)
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
|
||||
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
@@ -28,7 +28,7 @@ class TraceIndice(object):
|
||||
node_list (List)
|
||||
"""
|
||||
|
||||
def __init__(self, node_list: List) -> None:
|
||||
def __init__(self, node_list: List[Node]) -> None:
|
||||
self.node_list = node_list
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
@@ -198,7 +198,7 @@ class TraceIndice(object):
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
return self.indice_trace_list[node_idx]["compute"]
|
||||
|
||||
def _assign_indice_as_input(self, node, node_idx, input_node=None):
|
||||
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
@@ -216,7 +216,7 @@ class TraceIndice(object):
|
||||
|
||||
self._inherit_all_computation(input_node, node)
|
||||
|
||||
def _assign_all_indice(self, node, node_idx):
|
||||
def _assign_all_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Add new indice for all node's dims.
|
||||
|
||||
@@ -232,7 +232,7 @@ class TraceIndice(object):
|
||||
new_trace.append(self._add_indice())
|
||||
self.indice_trace_list[node_idx]["indice"] = new_trace
|
||||
|
||||
def _assign_transpose_indice(self, node, node_idx):
|
||||
def _assign_transpose_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for transpose op.
|
||||
1. swap input's dim according to transpose args
|
||||
@@ -249,7 +249,7 @@ class TraceIndice(object):
|
||||
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
||||
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
||||
|
||||
def _assign_permute_indice(self, node, node_idx):
|
||||
def _assign_permute_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for permute op.
|
||||
1. swap input's dim according to permute args
|
||||
@@ -259,14 +259,14 @@ class TraceIndice(object):
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = unflat_list(node.args[1:])
|
||||
permute_dim = flat_list(node.args[1:])
|
||||
input_node = node.args[0]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
self._inherit_indice(input_node, d, node, idx)
|
||||
|
||||
def _assign_linear_indice(self, node, node_idx):
|
||||
def _assign_linear_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for linear op.
|
||||
1. copy trace from input node and change last indice accroding to weight
|
||||
@@ -287,7 +287,7 @@ class TraceIndice(object):
|
||||
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
|
||||
def _assign_matmul_indice(self, node, node_idx):
|
||||
def _assign_matmul_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for matmul op.
|
||||
1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
|
||||
@@ -393,7 +393,7 @@ class TraceIndice(object):
|
||||
self._assign_indice_as_input(node, idx)
|
||||
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
||||
|
||||
def _assign_unsqueeze_indice(self, node, node_idx):
|
||||
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for unsqueeze op.
|
||||
1. assign new indice for unsqueeze dim
|
||||
@@ -404,9 +404,13 @@ class TraceIndice(object):
|
||||
"""
|
||||
self._del_dim(node_idx, -1)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
self._add_dim(node_idx, node.args[1])
|
||||
dim_idx = node.args[1]
|
||||
# unsqueeze(-1) = unsqueeze(shape_num + 1)
|
||||
if dim_idx < 0:
|
||||
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
|
||||
self._add_dim(node_idx, dim_idx)
|
||||
|
||||
def _assign_dropout_indice(self, node, node_idx):
|
||||
def _assign_dropout_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for unsqueeze op.
|
||||
1. assign new indice for unsqueeze dim
|
||||
@@ -417,7 +421,7 @@ class TraceIndice(object):
|
||||
"""
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
def _assign_ones_like_indice(self, node, node_idx):
|
||||
def _assign_ones_like_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for oneslike op.
|
||||
1. assign new indice for all dim
|
||||
@@ -428,7 +432,47 @@ class TraceIndice(object):
|
||||
"""
|
||||
self._assign_all_indice(node, node_idx)
|
||||
|
||||
def _assign_view_reshape_indice(self, node, node_idx):
|
||||
def _assign_getitem_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for getitem.
|
||||
getitem can act like slice sometimes
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
node_args = flat_list(node.args[1:])
|
||||
if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
|
||||
return
|
||||
|
||||
# node args should be like [Ellipsis, slice(start, step, end), None]
|
||||
node_shape = get_node_shape(node)
|
||||
origin_idx_count = 0
|
||||
new_idx_count = 0
|
||||
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
|
||||
for _ in range(new_dim_num):
|
||||
self._del_dim(node_idx, 0)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
for _, node_arg in enumerate(node_args):
|
||||
node_arg_str = str(node_arg)
|
||||
# Ellipsis means [..., ]
|
||||
if "Ellipsis" == node_arg_str:
|
||||
shape_gap = len(node_shape) - len(node_args) + 1
|
||||
origin_idx_count += shape_gap
|
||||
new_idx_count += shape_gap
|
||||
# slice(None, None, None) means all indexes, doesn't support other slice
|
||||
elif "slice(None, None, None)" == node_arg_str:
|
||||
origin_idx_count += 1
|
||||
new_idx_count += 1
|
||||
# None means a new dim
|
||||
elif "None" == node_arg_str:
|
||||
self._add_dim(node_idx, new_idx_count)
|
||||
new_idx_count += 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
|
||||
"""
|
||||
Assign indice for view and reshape op.
|
||||
1. get origin shape and target shape by meta info.
|
||||
@@ -447,7 +491,7 @@ class TraceIndice(object):
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||
target_shape = []
|
||||
unflated_args = unflat_list(node.args)
|
||||
unflated_args = flat_list(node.args)
|
||||
for i in range(1, len(unflated_args)):
|
||||
if isinstance(unflated_args[i], int):
|
||||
target_shape.append(unflated_args[i])
|
||||
@@ -544,6 +588,8 @@ class TraceIndice(object):
|
||||
self._assign_einsum_indice(node, idx)
|
||||
elif "layer_norm" in node.name:
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif "getitem" in node.name:
|
||||
self._assign_getitem_indice(node, idx)
|
||||
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
|
||||
continue
|
||||
else:
|
||||
|
Reference in New Issue
Block a user