mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[autochunk] support vit (#3084)
support vit for autochunk * support some new ops for vit * fix some bugs * add test for vit
This commit is contained in:
@@ -74,6 +74,9 @@ class TraceIndice(object):
|
||||
"""
|
||||
add a dim for indice, compute and source
|
||||
"""
|
||||
# need to remap if dim_idx < 0, e.g. -1
|
||||
if dim_idx < 0:
|
||||
dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx]
|
||||
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
|
||||
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
||||
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
||||
@@ -575,6 +578,60 @@ class TraceIndice(object):
|
||||
cat_dim = node.kwargs["dim"]
|
||||
self._del_dim(node_idx, cat_dim)
|
||||
|
||||
def _assign_flatten_indice(self, node: Node, node_idx: int) -> None:
    """
    Assign indice for flatten op.

    Only flattening through the last dim of the input is handled: the
    trailing dims starting at ``start_dim`` are removed from the trace and
    replaced by a single new trailing dim.

    Args:
        node (node)
        node_idx (int)

    Raises:
        NotImplementedError: if ``end_dim`` is not the last dim of the input
    """
    nodes_in = node.args[0]
    nodes_in_shape = get_node_shape(nodes_in)
    in_dim_num = len(nodes_in_shape)

    # flatten(input, start_dim=0, end_dim=-1): both dims may be omitted,
    # passed by keyword, or given as negative indices
    if len(node.args) > 1:
        flatten_start_dim = node.args[1]
    else:
        flatten_start_dim = node.kwargs.get("start_dim", 0)
    if len(node.args) > 2:
        flatten_end_dim = node.args[2]
    else:
        flatten_end_dim = node.kwargs.get("end_dim", -1)

    # need to remap if dim < 0, e.g. -1
    if flatten_start_dim < 0:
        flatten_start_dim += in_dim_num
    if flatten_end_dim < 0:
        flatten_end_dim += in_dim_num

    # the del/add sequence below always rewrites the trailing dims, so a
    # flatten that stops before the last dim would produce a wrong trace
    if flatten_end_dim != in_dim_num - 1:
        raise NotImplementedError("flatten with end_dim " + str(flatten_end_dim) + " on shape " +
                                  str(list(nodes_in_shape)) + " not implemented")

    flatten_dim_num = in_dim_num - flatten_start_dim - 1
    assert flatten_dim_num > 0, "flatten must merge at least two dims"

    # grow the trace by the number of dims flatten removes before taking the
    # indice of the input (presumably so both have the same rank)
    for _ in range(flatten_dim_num):
        self._add_dim(node_idx, 0)
    self._assign_indice_as_input(node, node_idx, nodes_in)

    # drop every flattened dim, then append one fresh dim for the merged result
    for _ in range(flatten_dim_num + 1):
        self._del_dim(node_idx, -1)
    self._add_dim(node_idx, -1)
|
||||
|
||||
def _assign_expand_indice(self, node: Node, node_idx: int) -> None:
    """
    Assign indice for expand op.

    Dims whose size is unchanged (or given as -1) keep the indice taken from
    the input; every dim that grows is deleted and re-added, which gives it a
    new indice.

    Args:
        node (node)
        node_idx (int)

    Raises:
        RuntimeError: if a target size is smaller than the input size
    """
    expand_shape = node.args[1:]
    # expand accepts the sizes either as varargs or as one sequence,
    # e.g. x.expand(2, 3) or x.expand((2, 3))
    if len(expand_shape) == 1 and isinstance(expand_shape[0], (tuple, list)):
        expand_shape = tuple(expand_shape[0])
    node_in_shape = get_node_shape(node.args[0])
    assert len(expand_shape) == len(node_in_shape), "expand that adds new leading dims is not supported"

    self._assign_indice_as_input(node, node_idx)
    for dim, (target_size, in_size) in enumerate(zip(expand_shape, node_in_shape)):
        if target_size == in_size or target_size == -1:
            # size unchanged, keep the indice of the input
            continue
        if target_size > in_size:
            # expanded dim: replace its indice with a new one
            self._del_dim(node_idx, dim)
            self._add_dim(node_idx, dim)
        else:
            raise RuntimeError("cannot expand dim {} from size {} to smaller size {}".format(
                dim, in_size, target_size))
|
||||
|
||||
def _assign_unbind_indice(self, node: Node, node_idx: int) -> None:
    """
    Assign indice for unbind op.

    A dim is temporarily added at the unbind position before taking the
    indice of the input, then removed again, so the remaining dims keep the
    indice of the input.

    Args:
        node (node)
        node_idx (int)
    """
    # unbind(input, dim=0): dim may be omitted or passed by keyword
    if len(node.args) > 1:
        unbind_dim = node.args[1]
    else:
        unbind_dim = node.kwargs.get("dim", 0)
    self._add_dim(node_idx, unbind_dim)
    self._assign_indice_as_input(node, node_idx)
    self._del_dim(node_idx, unbind_dim)
|
||||
|
||||
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
|
||||
"""
|
||||
Assign indice for embedding op.
|
||||
@@ -695,32 +752,39 @@ class TraceIndice(object):
|
||||
shape_idx = target_shape.index(-1)
|
||||
target_shape[shape_idx] = origin_product // target_product
|
||||
|
||||
# determine changed dim
|
||||
len_diff = len(origin_shape) - len(target_shape)
|
||||
if len_diff == 1:
|
||||
# find same dim
|
||||
dim_to_same_dim = []
|
||||
dim_from_same_dim = []
|
||||
for i in range(len(origin_shape)):
|
||||
if origin_shape[i] == target_shape[i]:
|
||||
dim_to_same_dim.append(i)
|
||||
dim_from_same_dim.append(i)
|
||||
else:
|
||||
break
|
||||
for i in range(-1, -len(origin_shape), -1):
|
||||
if origin_shape[i] == target_shape[i]:
|
||||
dim_to_same_dim.append(len(target_shape) + i)
|
||||
dim_from_same_dim.append(len(origin_shape) + i)
|
||||
else:
|
||||
break
|
||||
|
||||
dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim))
|
||||
dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim))
|
||||
assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to)
|
||||
|
||||
dim_diff = len(dim_from) - len(dim_to)
|
||||
if dim_diff > 0:
|
||||
# dim merge
|
||||
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||
dim_to = [dim_equal.index(False)]
|
||||
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._add_dim(node_idx, -1)
|
||||
elif len_diff == -1:
|
||||
for i in range(dim_diff):
|
||||
self._add_dim(node_idx, -1)
|
||||
elif dim_diff < 0:
|
||||
# dim expand
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = [dim_equal.index(False)]
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
elif len_diff == 0:
|
||||
# dim equal
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = []
|
||||
dim_to = []
|
||||
else:
|
||||
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
|
||||
for i in range(-dim_diff):
|
||||
self._del_dim(node_idx, -1)
|
||||
|
||||
# get new indice
|
||||
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||
self._assign_indice_as_input(node, node_idx, origin_node)
|
||||
idx_from = [origin_trace[i] for i in dim_from]
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
self._del_dim(node_idx, i)
|
||||
@@ -728,36 +792,18 @@ class TraceIndice(object):
|
||||
self._add_dim(node_idx, i)
|
||||
dim_from.reverse()
|
||||
|
||||
# search view list
|
||||
# for view_node, view_dict in self.indice_view_list.items():
|
||||
# if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
|
||||
# and view_dict["dim_from"] == dim_to):
|
||||
# # inheirt indice from current node
|
||||
# if len_diff == 1:
|
||||
# if origin_shape[dim_from[0]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
# elif origin_shape[dim_from[1]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# elif len_diff == -1:
|
||||
# if target_shape[dim_to[0]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
# elif target_shape[dim_to[1]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# # inherid indice from input node of last view
|
||||
# for dim_to_i in dim_to:
|
||||
# self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
|
||||
|
||||
# inherit indice from current node
|
||||
if len_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
elif origin_shape[dim_from[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
elif len_diff == -1:
|
||||
if target_shape[dim_to[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
elif target_shape[dim_to[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
if len(dim_from) != 0 and len(dim_to) != 0:
|
||||
if dim_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
elif origin_shape[dim_from[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
elif dim_diff == -1:
|
||||
if target_shape[dim_to[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
elif target_shape[dim_to[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
|
||||
# log view, not used now
|
||||
view_dict = {
|
||||
@@ -809,6 +855,14 @@ class TraceIndice(object):
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
elif "new_ones" == node_name:
|
||||
self._assign_all_indice(node, idx)
|
||||
elif "flatten" == node_name:
|
||||
self._assign_flatten_indice(node, idx)
|
||||
elif "expand" == node_name:
|
||||
self._assign_expand_indice(node, idx)
|
||||
elif "unbind" == node_name:
|
||||
self._assign_unbind_indice(node, idx)
|
||||
elif "softmax" == node_name:
|
||||
self._assign_softmax_indice(node, idx)
|
||||
elif any(i == node_name for i in ["size"]):
|
||||
continue
|
||||
else:
|
||||
@@ -859,7 +913,9 @@ class TraceIndice(object):
|
||||
self._assign_linear_indice(node, idx)
|
||||
elif "conv2d" == node_name:
|
||||
self._assign_conv2d_indice(node, idx)
|
||||
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
|
||||
elif "identity" == node_name:
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node_name, "module not implemented yet!")
|
||||
|
Reference in New Issue
Block a user