mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[autochunk] support diffusion for autochunk (#2621)
* add alphafold benchmark * renae alphafold test * rename tests * rename diffuser * renme * rename * update transformer * update benchmark * update benchmark * update bench memory * update transformer benchmark * rename * support diffuser * support unet metainfo prop * fix bug and simplify code * update linear and support some op * optimize max region search, support conv * update unet test * support some op * support groupnorm and interpolate * update flow search * add fix dim in node flow * fix utils * rename * support diffusion * update diffuser * update chunk search * optimize imports * import * finish autochunk
This commit is contained in:
@@ -100,6 +100,16 @@ class TraceFlow(object):
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
|
||||
# get fix dim
|
||||
arg_fix_dim = []
|
||||
if cur_node_dim is not None:
|
||||
for i in cur_node_fix_dim:
|
||||
fix_dim_source = cur_node_source[i]
|
||||
if arg_idx in fix_dim_source:
|
||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||
if arg_node in all_node_info:
|
||||
arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
|
||||
|
||||
# find arg dim
|
||||
if cur_node_dim is not None:
|
||||
# dim is computed
|
||||
@@ -109,6 +119,9 @@ class TraceFlow(object):
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||
# chunk dim cannot be in fix dims
|
||||
if arg_dim in arg_fix_dim:
|
||||
return False
|
||||
# chunk dim should be None if shape size is 1
|
||||
if get_node_shape(arg_node)[arg_dim] == 1:
|
||||
arg_dim = None
|
||||
@@ -120,19 +133,16 @@ class TraceFlow(object):
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
# get fix dim
|
||||
arg_fix_dim = []
|
||||
if cur_node_dim is not None:
|
||||
for i in cur_node_fix_dim:
|
||||
fix_dim_source = cur_node_source[i]
|
||||
if arg_idx in fix_dim_source:
|
||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||
# add arg rest dim as fix dim
|
||||
arg_fix_dim = list(range(len(get_node_shape(arg_node))))
|
||||
if arg_dim is not None:
|
||||
arg_fix_dim.remove(arg_dim)
|
||||
|
||||
# if already in node_info, arg dim must be same
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
|
||||
all_node_info[arg_node]["fix_dim"] = arg_fix_dim
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||
@@ -164,6 +174,8 @@ class TraceFlow(object):
|
||||
continue
|
||||
if is_non_compute_node(arg):
|
||||
continue
|
||||
if get_node_shape(arg) is None:
|
||||
continue
|
||||
arg_list.append(arg)
|
||||
flow_flag = self._assgin_single_node_flow(
|
||||
arg,
|
||||
@@ -180,29 +192,6 @@ class TraceFlow(object):
|
||||
if flow_flag == False:
|
||||
return None
|
||||
|
||||
if len(arg_list) >= 2:
|
||||
# need to mark fix dim
|
||||
if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
|
||||
for arg in arg_list:
|
||||
if get_node_shape(arg) is None:
|
||||
continue
|
||||
if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
arg_shape = get_node_shape(arg)
|
||||
# add all dim as fix dim except chunk dim
|
||||
for i, shape in enumerate(arg_shape):
|
||||
if shape != 1 and i != cur_node_chunk_dim:
|
||||
if i == arg_chunk_dim:
|
||||
return None
|
||||
if i not in arg_fix_dim:
|
||||
arg_fix_dim.append(i)
|
||||
elif any(i == get_node_name(cur_node)
|
||||
for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
cur_node_list = next_node_list
|
||||
return all_node_info
|
||||
|
||||
|
Reference in New Issue
Block a user