[autochunk] support diffusion for autochunk (#2621)
* add alphafold benchmark
* rename alphafold test
* rename tests
* rename diffuser
* rename
* rename
* update transformer
* update benchmark
* update benchmark
* update bench memory
* update transformer benchmark
* rename
* support diffuser
* support unet metainfo prop
* fix bug and simplify code
* update linear and support some op
* optimize max region search, support conv
* update unet test
* support some op
* support groupnorm and interpolate
* update flow search
* add fix dim in node flow
* fix utils
* rename
* support diffusion
* update diffuser
* update chunk search
* optimize imports
* import
* finish autochunk
colossalai/autochunk/trace_indice.py

@@ -150,7 +150,7 @@ class TraceIndice(object):
         for i in range(len(node_from_indice)):
             self._inherit_indice(node_from, i, node_to, i, init=True)
 
-    def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
+    def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
         """
         inherit indice from node without init
         """
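The rename makes the helper's contract explicit: it merges per-dimension trace indices from another node and can skip selected dimensions. As a rough illustration of the idea only (a hypothetical, simplified sketch, not the library's implementation), aligning trailing dims and honoring an exclude list looks like this:

from typing import List, Optional

def inherit_with_exclude(src: List[int], dst: List[int],
                         exclude: Optional[List[int]] = None) -> List[int]:
    # hypothetical sketch: copy dim indices from src onto dst, right-aligned,
    # skipping any dim listed in `exclude` (negative dims normalized)
    skip = {e % len(dst) for e in (exclude or [])}
    out = list(dst)
    for offset in range(1, min(len(src), len(dst)) + 1):
        dim = len(dst) - offset
        if dim not in skip:
            out[dim] = src[len(src) - offset]
    return out

# e.g. a matmul call site: inherit batch dims from the right operand but
# keep the target's own [-1, -2] (the matmul dims)
assert inherit_with_exclude([7, 8, 9], [1, 2, 3], exclude=[-1, -2]) == [7, 2, 3]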
@@ -308,14 +308,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        if len(node.args) == 2:
-            _, weight = node.args
-        else:
-            _, weight, _ = node.args
-
         self._assign_indice_as_input(node, node_idx)
-        self._inherit_indice(weight, 1, node, -1)
 
+        if len(node.args) >= 2:
+            weight = node.args[1]
+            self._inherit_indice(weight, 1, node, -1)
+        else:
+            self._del_dim(node_idx, -1)
+            self._add_dim(node_idx, -1)
         self._mark_computation(node, node_idx, [-1])
 
     def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
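The switch from `len(node.args) == 2` to `>= 2` reflects how `F.linear` shows up in an FX graph with either two or three positional args depending on bias; when no weight argument is present, the last dim simply gets a fresh index (`_del_dim` + `_add_dim`). A quick demonstration of the arity difference (assuming PyTorch; shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
w = torch.randn(32, 16)          # (out_features, in_features)
b = torch.randn(32)

y = F.linear(x, w)               # traced node.args == (x, w)    -> len == 2
y_b = F.linear(x, w, b)          # traced node.args == (x, w, b) -> len == 3
assert y.shape == y_b.shape == (2, 8, 32)   # only the last dim changes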
@@ -327,13 +327,35 @@ class TraceIndice(object):
             node_idx (int)
         """
         bias, input_node, weight = node.args
 
+        assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
         self._assign_indice_as_input(node, node_idx, input_node)
         self._inherit_indice(weight, 1, node, -1)
-        self._inherit_indice(bias, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(bias, node)
 
         self._mark_computation(node, node_idx, [-1])
 
+    def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for baddbmm (batch add and batch matmul) op.
+        add, matmul_left, matmul_right = args
+        out = add + (matmul_left x matmul_right)
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        add, matmul_left, matmul_right = node.args
+
+        assert get_node_shape(add) == get_node_shape(node)
+        assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
+        self._assign_indice_as_input(node, node_idx, matmul_left)
+        # matmul
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
+        self._mark_computation(node, node_idx, [-1])
+        # add
+        self._inherit_more_indice_from_node_with_exclude(add, node)
+
     def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for matmul op.
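The new `_assign_baddbmm_indice` encodes exactly the contract stated in its docstring: `torch.baddbmm` is a batched matmul plus a broadcastable add, so batch dims follow the left operand, the last dim follows the right operand, and the add term contributes no new dims. For reference (assuming PyTorch):

import torch

add = torch.randn(4, 8, 16)            # must broadcast to the output shape
left = torch.randn(4, 8, 32)
right = torch.randn(4, 32, 16)

out = torch.baddbmm(add, left, right)  # out = add + bmm(left, right)
assert out.shape == (4, 8, 16)
torch.testing.assert_close(out, add + torch.bmm(left, right))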
@@ -349,11 +371,53 @@ class TraceIndice(object):
 
         assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
         self._assign_indice_as_input(node, node_idx, matmul_left)
-        self._inherit_indice(matmul_right, -1, node, -1)
-
-        self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
         self._mark_computation(node, node_idx, [-1])
 
+    def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for conv2d op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get conv module
+        node_targets = node.target.split(".")
+        conv_module = node.graph.owning_module
+        for i in node_targets:
+            conv_module = getattr(conv_module, i)
+        assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
+
+        # get conv input
+        assert len(node.args) == 1
+        input_node = node.args[0]
+        assert len(get_node_shape(input_node)) == 4
+
+        # assign index
+        self._assign_indice_as_input(node, node_idx, input_node)
+        self._del_dim(node_idx, 1)
+        self._add_dim(node_idx, 1)
+        self._mark_computation(node, node_idx, [1, 2, 3])
+
+    def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for interpolate op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get interpolate input
+        assert node.kwargs['size'] is None
+        assert len(get_node_shape(node)) == 4
+
+        # assign index
+        self._assign_indice_as_input(node, node_idx)
+        self._mark_computation(node, node_idx, [-1, -2])
+
     def _assign_layernorm_indice(self, node, idx):
         """
         Assign indice for layernorm op.
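`_assign_conv2d_indice` has to recover the actual `nn.Conv2d` instance behind a `call_module` node in order to check its dilation; the dotted `node.target` is walked with `getattr` from the owning module. A minimal sketch of that resolution (module and attribute names here are illustrative, not from the commit):

import torch
import torch.fx as fx

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

gm = fx.symbolic_trace(Block())
node = next(n for n in gm.graph.nodes if n.op == "call_module")
mod = gm                             # node.graph.owning_module in the tracer
for part in node.target.split("."):  # e.g. "conv" or a dotted path
    mod = getattr(mod, part)
assert isinstance(mod, torch.nn.Conv2d) and mod.dilation == (1, 1)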
@@ -367,6 +431,18 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [-1])
 
+    def _assign_groupnorm_indice(self, node, idx):
+        """
+        Assign indice for groupnorm op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        assert len(get_node_shape(node)) == 4
+        self._assign_indice_as_input(node, idx)
+        self._mark_computation(node, idx, [-1, -2, -3])
+
     def _assign_elementwise_indice(self, node, idx):
         """
         Assign indice for element-wise op (eg. relu sigmoid add mul).
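GroupNorm computes statistics over channel groups and both spatial dims of an NCHW input, which is why the new rule asserts a 4-dim shape and marks dims -1, -2, -3 as computed, leaving only the batch dim freely chunkable. For reference (assuming PyTorch):

import torch

x = torch.randn(2, 8, 4, 4)                              # NCHW
gn = torch.nn.GroupNorm(num_groups=2, num_channels=8)
y = gn(x)
# per-sample statistics span the channel and spatial dims; shape is preserved
assert y.shape == x.shape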
@@ -382,13 +458,13 @@ class TraceIndice(object):
         for node_in in node.args:
             if type(node_in) == type(node):
                 nodes_in.append(node_in)
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)
 
     def _assgin_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)
 
     def _assign_einsum_indice(self, node, idx):
         """
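The `type(node_in) == type(node)` guard filters `node.args` down to actual graph nodes, since FX mixes `fx.Node` inputs with plain Python constants. An equivalent check with `isinstance` (illustrative only, not the tracer's code):

import torch
import torch.fx as fx

def f(x):
    return x + 1.0

gm = fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
node_inputs = [a for a in add_node.args if isinstance(a, fx.Node)]
# args are (x, 1.0): one graph node, one constant
assert len(add_node.args) == 2 and len(node_inputs) == 1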
@@ -469,17 +545,6 @@ class TraceIndice(object):
             dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
         self._add_dim(node_idx, dim_idx)
 
-    def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for oneslike op.
-        1. assign new indice for all dim
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
     def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for cat op.
@@ -491,7 +556,7 @@ class TraceIndice(object):
         nodes_in = flat_list(node.args[0])
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
         self._add_dim(node_idx, cat_dim)
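`cat` preserves every dim except the concat dim, whose length changes, so the tracer inherits indices from all inputs and then re-assigns (`_del_dim` + `_add_dim`) only `cat_dim`. For reference (assuming PyTorch):

import torch

a, b = torch.randn(2, 3, 8), torch.randn(2, 5, 8)
out = torch.cat([a, b], dim=1)
assert out.shape == (2, 8, 8)   # dims 0 and 2 unchanged, dim 1 is new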
@@ -508,33 +573,10 @@ class TraceIndice(object):
         self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
 
-    def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for arange op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
-    def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for tensor op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        if len(get_node_shape(node)) == 0:
-            return
-        else:
-            raise NotImplementedError()
-
     def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for embedding op.
@@ -763,10 +805,10 @@ class TraceIndice(object):
                     self._assign_unsqueeze_indice(node, idx)
                 elif "split" == node_name:
                     self._assign_split_indice(node, idx)
-                elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
+                elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
                     self._assgin_no_change_indice(node, idx)
                 elif "new_ones" == node_name:
-                    self._assign_ones_like_indice(node, idx)
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["size"]):
                     continue
                 else:
@@ -776,25 +818,15 @@ class TraceIndice(object):
                     self._assign_linear_indice(node, idx)
                 elif "cat" == node_name:
                     self._assign_cat_indice(node, idx)
-                elif "matmul" == node_name:
+                elif any(n == node_name for n in ["matmul", "bmm"]):
                     self._assign_matmul_indice(node, idx)
                 elif "softmax" == node_name:
                     self._assign_softmax_indice(node, idx)
                 elif any(n == node_name for n in [
-                        "mul",
-                        "add",
-                        "sigmoid",
-                        "relu",
-                        "sub",
-                        "truediv",
-                        "pow",
-                        "dropout",
-                        "where",
-                        "tanh",
+                        "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
+                        "sin", "cos"
                 ]):
                     self._assign_elementwise_indice(node, idx)
-                elif "ones_like" == node_name:
-                    self._assign_ones_like_indice(node, idx)
                 elif "einsum" == node_name:
                     self._assign_einsum_indice(node, idx)
                 elif "sum" == node_name:
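Collapsing the matmul family into one branch and compacting the element-wise list is safe because every op in the list is shape-preserving per dim, so a single indice-inheritance rule covers them all, including the newly added "exp", "sin" and "cos". A trivial check (assuming PyTorch):

import torch

x = torch.randn(2, 3, 4)
for op in (torch.exp, torch.sin, torch.cos):
    assert op(x).shape == x.shape   # element-wise: every dim preserved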
@@ -805,10 +837,12 @@ class TraceIndice(object):
                     self._assign_getitem_indice(node, idx)
                 elif "addmm" == node_name:
                     self._assign_addmm_indice(node, idx)
-                elif "arange" == node_name:
-                    self._assign_arange_indice(node, idx)
-                elif "tensor" == node_name:
-                    self._assign_arange_indice(node, idx)
+                elif "baddbmm" == node_name:
+                    self._assign_baddbmm_indice(node, idx)
+                elif "interpolate" == node_name:
+                    self._assign_interpolate_indice(node, idx)
+                elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
                     continue
                 else:
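The interpolate rule registered here pairs with the assert above: only the `scale_factor` form is traced (`kwargs["size"]` must be None), and the two spatial dims are re-computed. For reference (assuming PyTorch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
y = F.interpolate(x, scale_factor=2.0, mode="nearest")   # size=None
assert y.shape == (1, 3, 16, 16)   # dims -1, -2 change, batch/channel kept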
@@ -817,9 +851,15 @@ class TraceIndice(object):
                 node_name = get_module_node_name(node)
                 if "layernorm" == node_name:
                     self._assign_layernorm_indice(node, idx)
+                elif "groupnorm" == node_name:
+                    self._assign_groupnorm_indice(node, idx)
                 elif "embedding" == node_name:
                     self._assign_embedding_indice(node, idx)
-                elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
+                elif "linear" == node_name:
+                    self._assign_linear_indice(node, idx)
+                elif "conv2d" == node_name:
+                    self._assign_conv2d_indice(node, idx)
+                elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
                     self._assign_elementwise_indice(node, idx)
                 else:
                     raise NotImplementedError(node_name, "module not implemented yet!")