[autochunk] support diffusion for autochunk (#2621)

* add alphafold benchmark

* rename alphafold test

* rename tests

* rename diffuser

* rename

* rename

* update transformer

* update benchmark

* update benchmark

* update bench memory

* update transformer benchmark

* rename

* support diffuser

* support unet metainfo prop

* fix bug and simplify code

* update linear and support some op

* optimize max region search, support conv

* update unet test

* support some op

* support groupnorm and interpolate

* update flow search

* add fix dim in node flow

* fix utils

* rename

* support diffusion

* update diffuser

* update chunk search

* optimize imports

* import

* finish autochunk
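
For context on what the bullets above build toward: autochunk cuts peak activation memory by splitting one region of the traced graph along a single tensor dimension (the chunk dim) and executing it slice by slice, while dimensions an op must see whole (the fix dims) are barred from chunking. Below is a minimal, self-contained sketch of the execution idea, independent of ColossalAI's actual codegen (all names are illustrative):

import torch
import torch.nn.functional as F

def chunked_apply(fn, x: torch.Tensor, chunk_dim: int, n_chunks: int) -> torch.Tensor:
    # run fn slice by slice along chunk_dim and stitch the results back;
    # peak activation memory of fn drops roughly by a factor of n_chunks
    outs = [fn(piece) for piece in x.chunk(n_chunks, dim=chunk_dim)]
    return torch.cat(outs, dim=chunk_dim)

x = torch.randn(8, 4096, 320)    # e.g. a flattened unet activation
y_ref = F.gelu(x)
y_chk = chunked_apply(F.gelu, x, chunk_dim=1, n_chunks=8)
assert torch.allclose(y_ref, y_chk)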
oahzxl authored on 2023-02-07 16:32:45 +08:00, committed by GitHub
parent 291b051171
commit 6ba8364881
6 changed files with 216 additions and 166 deletions


@@ -100,6 +100,16 @@ class TraceFlow(object):
         if not (start_idx <= arg_idx < end_idx):
             return True
 
+        # get fix dim
+        arg_fix_dim = []
+        if cur_node_dim is not None:
+            for i in cur_node_fix_dim:
+                fix_dim_source = cur_node_source[i]
+                if arg_idx in fix_dim_source:
+                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+        if arg_node in all_node_info:
+            arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+
         # find arg dim
         if cur_node_dim is not None:
             # dim is computed
@@ -109,6 +119,9 @@ class TraceFlow(object):
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim cannot be in fix dims
+                if arg_dim in arg_fix_dim:
+                    return False
                 # chunk dim should be None if shape size is 1
                 if get_node_shape(arg_node)[arg_dim] == 1:
                     arg_dim = None
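
The two hunks above hoist the fix-dim lookup ahead of the chunk-dim computation, so the new `arg_dim in arg_fix_dim` check can veto a conflicting assignment as soon as the arg's chunk dim is resolved. A toy rendering of that check, using hypothetical stand-ins for the tracer's bookkeeping (here `cur_node_source` maps each dim of the current node to `{producer node index: [producer dim, ...]}`, matching how the diff indexes it):

# hypothetical stand-ins for the tracer's bookkeeping
cur_node_fix_dim = [0, 2]          # dims of cur_node that must stay whole
cur_node_source = {
    0: {7: [1]},                   # cur_node dim 0 comes from node 7's dim 1
    1: {7: [0]},                   # cur_node dim 1 comes from node 7's dim 0
    2: {9: [2]},                   # cur_node dim 2 comes from node 9's dim 2
}
arg_idx, cur_node_dim = 7, 1

# collect the arg dims that feed cur_node's fix dims
arg_fix_dim = []
for i in cur_node_fix_dim:
    fix_dim_source = cur_node_source[i]
    if arg_idx in fix_dim_source:
        arg_fix_dim.append(fix_dim_source[arg_idx][0])

# the arg dim that would carry the chunk
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]

# chunk dim cannot also be a fix dim
assert arg_dim not in arg_fix_dim  # here: arg_dim=0, arg_fix_dim=[1]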
@@ -120,19 +133,16 @@ class TraceFlow(object):
         else:
             arg_dim = None
-        # get fix dim
-        arg_fix_dim = []
-        if cur_node_dim is not None:
-            for i in cur_node_fix_dim:
-                fix_dim_source = cur_node_source[i]
-                if arg_idx in fix_dim_source:
-                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+
+        # add arg rest dim as fix dim
+        arg_fix_dim = list(range(len(get_node_shape(arg_node))))
         if arg_dim is not None:
             arg_fix_dim.remove(arg_dim)
+
         # if already in node_info, arg dim must be same
         if arg_node in all_node_info:
             if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                 return False
-            all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+            all_node_info[arg_node]["fix_dim"] = arg_fix_dim
         # else add it to list
         else:
             all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -164,6 +174,8 @@ class TraceFlow(object):
                         continue
                     if is_non_compute_node(arg):
                         continue
+                    if get_node_shape(arg) is None:
+                        continue
                     arg_list.append(arg)
                     flow_flag = self._assgin_single_node_flow(
                         arg,
@@ -180,29 +192,6 @@ class TraceFlow(object):
                     if flow_flag == False:
                         return None
 
-                if len(arg_list) >= 2:
-                    # need to mark fix dim
-                    if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
-                        for arg in arg_list:
-                            if get_node_shape(arg) is None:
-                                continue
-                            if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
-                                continue
-                            arg_chunk_dim = all_node_info[arg]["chunk_dim"]
-                            arg_fix_dim = all_node_info[arg]["fix_dim"]
-                            arg_shape = get_node_shape(arg)
-                            # add all dim as fix dim except chunk dim
-                            for i, shape in enumerate(arg_shape):
-                                if shape != 1 and i != cur_node_chunk_dim:
-                                    if i == arg_chunk_dim:
-                                        return None
-                                    if i not in arg_fix_dim:
-                                        arg_fix_dim.append(i)
-                    elif any(i == get_node_name(cur_node)
-                             for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
-                        pass
-                    else:
-                        raise NotImplementedError()
             cur_node_list = next_node_list
         return all_node_info
 
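
With the per-op special-casing deleted above (its `get_node_shape(arg) is None` guard survives as the generic skip added in the previous hunk), `_get_all_node_info` reduces to a breadth-first walk from the chunk's tail node, with `_assgin_single_node_flow` either extending the chunk-dim assignment to each producer or vetoing the region. A compressed, hypothetical sketch of that loop shape:

from dataclasses import dataclass
from typing import Callable, Dict, Optional

@dataclass(frozen=True)
class Node:                        # stand-in for torch.fx.Node
    name: str
    args: tuple = ()

def get_all_node_info(out_node: Node, chunk_dim: int, assign_fn: Callable) -> Optional[Dict]:
    # breadth-first walk from the chunk output back through its producers;
    # assign_fn mirrors _assgin_single_node_flow: it fills all_node_info and
    # returns False when a chunk-dim / fix-dim conflict makes the region invalid
    all_node_info = {out_node: {"chunk_dim": chunk_dim, "fix_dim": []}}
    cur_node_list = [out_node]
    while cur_node_list:
        next_node_list = []
        for cur_node in cur_node_list:
            for arg in cur_node.args:
                if not assign_fn(arg, cur_node, all_node_info, next_node_list):
                    return None    # flow conflict: this chunk region is invalid
        cur_node_list = next_node_list
    return all_node_info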