[autochunk] support diffusion for autochunk (#2621)

* add alphafold benchmark

* renae alphafold test

* rename tests

* rename diffuser

* renme

* rename

* update transformer

* update benchmark

* update benchmark

* update bench memory

* update transformer benchmark

* rename

* support diffuser

* support unet metainfo prop

* fix bug and simplify code

* update linear and support some op

* optimize max region search, support conv

* update unet test

* support some op

* support groupnorm and interpolate

* update flow search

* add fix dim in node flow

* fix utils

* rename

* support diffusion

* update diffuser

* update chunk search

* optimize imports

* import

* finish autochunk
This commit is contained in:
oahzxl
2023-02-07 16:32:45 +08:00
committed by GitHub
parent 291b051171
commit 6ba8364881
6 changed files with 216 additions and 166 deletions

View File

@@ -150,7 +150,7 @@ class TraceIndice(object):
for i in range(len(node_from_indice)):
self._inherit_indice(node_from, i, node_to, i, init=True)
def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
"""
inheirt indice from node without init
"""
@@ -308,14 +308,14 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
if len(node.args) == 2:
_, weight = node.args
else:
_, weight, _ = node.args
self._assign_indice_as_input(node, node_idx)
self._inherit_indice(weight, 1, node, -1)
if len(node.args) >= 2:
weight = node.args[1]
self._inherit_indice(weight, 1, node, -1)
else:
self._del_dim(node_idx, -1)
self._add_dim(node_idx, -1)
self._mark_computation(node, node_idx, [-1])
def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
@@ -327,13 +327,35 @@ class TraceIndice(object):
node_idx (int)
"""
bias, input_node, weight = node.args
assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
self._assign_indice_as_input(node, node_idx, input_node)
self._inherit_indice(weight, 1, node, -1)
self._inherit_indice(bias, -1, node, -1)
self._inherit_more_indice_from_node_with_exclude(bias, node)
self._mark_computation(node, node_idx, [-1])
def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for baddbmm(batch add and batch matmul) op.
add, matmul_left, matmul_right = args
out = add + (matmul_left x matmul_right)
Args:
node (node)
node_idx (int)
"""
add, matmul_left, matmul_right = node.args
assert get_node_shape(add) == get_node_shape(node)
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_indice_as_input(node, node_idx, matmul_left)
# matmul
self._inherit_indice(matmul_right, -1, node, -1)
self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
self._mark_computation(node, node_idx, [-1])
# add
self._inherit_more_indice_from_node_with_exclude(add, node)
def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for matmul op.
@@ -349,11 +371,53 @@ class TraceIndice(object):
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
self._assign_indice_as_input(node, node_idx, matmul_left)
self._inherit_indice(matmul_right, -1, node, -1)
self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
self._inherit_indice(matmul_right, -1, node, -1)
self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
self._mark_computation(node, node_idx, [-1])
def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for conv2d op.
Args:
node (node)
node_idx (int)
"""
# get conv module
node_targets = node.target.split(".")
conv_module = node.graph.owning_module
for i in node_targets:
conv_module = getattr(conv_module, i)
assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
# get conv input
assert len(node.args) == 1
input_node = node.args[0]
assert len(get_node_shape(input_node)) == 4
# assgin index
self._assign_indice_as_input(node, node_idx, input_node)
self._del_dim(node_idx, 1)
self._add_dim(node_idx, 1)
self._mark_computation(node, node_idx, [1, 2, 3])
def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for interpolate op.
Args:
node (node)
node_idx (int)
"""
# get conv input
assert node.kwargs['size'] is None
assert len(get_node_shape(node)) == 4
# assgin index
self._assign_indice_as_input(node, node_idx)
self._mark_computation(node, node_idx, [-1, -2])
def _assign_layernorm_indice(self, node, idx):
"""
Assign indice for layernorm op.
@@ -367,6 +431,18 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1])
def _assign_groupnorm_indice(self, node, idx):
"""
Assign indice for groupnorm op.
Args:
node (node)
node_idx (int)
"""
assert len(get_node_shape(node)) == 4
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [-1, -2, -3])
def _assign_elementwise_indice(self, node, idx):
"""
Assign indice for element-wise op (eg. relu sigmoid add mul).
@@ -382,13 +458,13 @@ class TraceIndice(object):
for node_in in node.args:
if type(node_in) == type(node):
nodes_in.append(node_in)
self._inherit_more_indice_from_node(node_in, node)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assgin_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._inherit_more_indice_from_node(node_in, node)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assign_einsum_indice(self, node, idx):
"""
@@ -469,17 +545,6 @@ class TraceIndice(object):
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for oneslike op.
1. assign new indice for all dim
Args:
node (node)
node_idx (int)
"""
self._assign_all_indice(node, node_idx)
def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for cat op.
@@ -491,7 +556,7 @@ class TraceIndice(object):
nodes_in = flat_list(node.args[0])
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._inherit_more_indice_from_node(n, node)
self._inherit_more_indice_from_node_with_exclude(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
self._add_dim(node_idx, cat_dim)
@@ -508,33 +573,10 @@ class TraceIndice(object):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
for n in nodes_in[1:]:
self._inherit_more_indice_from_node(n, node)
self._inherit_more_indice_from_node_with_exclude(n, node)
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for arange op.
Args:
node (node)
node_idx (int)
"""
self._assign_all_indice(node, node_idx)
def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for tensor op.
Args:
node (node)
node_idx (int)
"""
if len(get_node_shape(node)) == 0:
return
else:
raise NotImplementedError()
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for embedding op.
@@ -763,10 +805,10 @@ class TraceIndice(object):
self._assign_unsqueeze_indice(node, idx)
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_ones_like_indice(node, idx)
self._assign_all_indice(node, idx)
elif any(i == node_name for i in ["size"]):
continue
else:
@@ -776,25 +818,15 @@ class TraceIndice(object):
self._assign_linear_indice(node, idx)
elif "cat" == node_name:
self._assign_cat_indice(node, idx)
elif "matmul" == node_name:
elif any(n == node_name for n in ["matmul", "bmm"]):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
"sin", "cos"
]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" == node_name:
self._assign_ones_like_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
elif "sum" == node_name:
@@ -805,10 +837,12 @@ class TraceIndice(object):
self._assign_getitem_indice(node, idx)
elif "addmm" == node_name:
self._assign_addmm_indice(node, idx)
elif "arange" == node_name:
self._assign_arange_indice(node, idx)
elif "tensor" == node_name:
self._assign_arange_indice(node, idx)
elif "baddbmm" == node_name:
self._assign_baddbmm_indice(node, idx)
elif "interpolate" == node_name:
self._assign_interpolate_indice(node, idx)
elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
self._assign_all_indice(node, idx)
elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
continue
else:
@@ -817,9 +851,15 @@ class TraceIndice(object):
node_name = get_module_node_name(node)
if "layernorm" == node_name:
self._assign_layernorm_indice(node, idx)
elif "groupnorm" == node_name:
self._assign_groupnorm_indice(node, idx)
elif "embedding" == node_name:
self._assign_embedding_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
elif "linear" == node_name:
self._assign_linear_indice(node, idx)
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node_name, "module not implemented yet!")