mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-21 10:50:56 +00:00
format code
This commit is contained in:
parent
d361d533e8
commit
ded1005667
184
chunk_codegen.py
184
chunk_codegen.py
@ -144,7 +144,9 @@ class IndexTracer(object):
|
|||||||
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
|
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
|
||||||
else:
|
else:
|
||||||
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
|
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
|
||||||
node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim)
|
node_to_trace["source"][node_to_dim][node_from_idx].append(
|
||||||
|
node_from_dim
|
||||||
|
)
|
||||||
# update inputs source
|
# update inputs source
|
||||||
node_to_trace["source"][node_to_dim].update(
|
node_to_trace["source"][node_to_dim].update(
|
||||||
node_from_trace["source"][node_from_dim]
|
node_from_trace["source"][node_from_dim]
|
||||||
@ -745,7 +747,6 @@ class IndexTracer(object):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FlowTracer(object):
|
class FlowTracer(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
@ -856,7 +857,9 @@ class FlowTracer(object):
|
|||||||
)
|
)
|
||||||
return self.flow_trace
|
return self.flow_trace
|
||||||
|
|
||||||
def _detect_flow(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
|
def _detect_flow(
|
||||||
|
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
|
||||||
|
):
|
||||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
)
|
)
|
||||||
@ -945,8 +948,10 @@ class FlowTracer(object):
|
|||||||
for i in remove_inputs:
|
for i in remove_inputs:
|
||||||
if i in chunk_info["inputs"]:
|
if i in chunk_info["inputs"]:
|
||||||
chunk_info["inputs"].remove(i)
|
chunk_info["inputs"].remove(i)
|
||||||
|
|
||||||
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True)
|
duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(
|
||||||
|
chunk_info, return_dim=True
|
||||||
|
)
|
||||||
|
|
||||||
# we need to log input nodes to avoid deleteing them in the loop
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
non_chunk_inputs = _find_chunk_all_input_nodes(
|
non_chunk_inputs = _find_chunk_all_input_nodes(
|
||||||
@ -958,15 +963,25 @@ class FlowTracer(object):
|
|||||||
|
|
||||||
return flow_block, chunk_info
|
return flow_block, chunk_info
|
||||||
|
|
||||||
def _assgin_single_node_flow(self, arg_node, start_idx, end_idx,
|
def _assgin_single_node_flow(
|
||||||
inputs, index_tracer, cur_node_dim,
|
self,
|
||||||
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
|
arg_node,
|
||||||
next_node_list):
|
start_idx,
|
||||||
|
end_idx,
|
||||||
|
inputs,
|
||||||
|
index_tracer,
|
||||||
|
cur_node_dim,
|
||||||
|
cur_node_compute,
|
||||||
|
cur_node_source,
|
||||||
|
cur_node_fix_dim,
|
||||||
|
all_node_info,
|
||||||
|
next_node_list,
|
||||||
|
):
|
||||||
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
|
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
|
||||||
# arg in chunk range or be inputs
|
# arg in chunk range or be inputs
|
||||||
if not (start_idx <= arg_idx < end_idx):
|
if not (start_idx <= arg_idx < end_idx):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# find arg dim
|
# find arg dim
|
||||||
if cur_node_dim is not None:
|
if cur_node_dim is not None:
|
||||||
# dim is computed
|
# dim is computed
|
||||||
@ -978,7 +993,7 @@ class FlowTracer(object):
|
|||||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||||
else:
|
else:
|
||||||
arg_dim = None
|
arg_dim = None
|
||||||
|
|
||||||
# get fix dim
|
# get fix dim
|
||||||
arg_fix_dim = []
|
arg_fix_dim = []
|
||||||
if cur_node_dim is not None:
|
if cur_node_dim is not None:
|
||||||
@ -986,44 +1001,52 @@ class FlowTracer(object):
|
|||||||
fix_dim_source = cur_node_source[i]
|
fix_dim_source = cur_node_source[i]
|
||||||
if arg_idx in fix_dim_source:
|
if arg_idx in fix_dim_source:
|
||||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||||
|
|
||||||
# if already in node_info, arg dim must be same
|
# if already in node_info, arg dim must be same
|
||||||
if arg_node in all_node_info:
|
if arg_node in all_node_info:
|
||||||
if all_node_info[arg_node] != arg_dim:
|
if all_node_info[arg_node] != arg_dim:
|
||||||
return False
|
return False
|
||||||
all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim))
|
all_node_info[arg_node]["fix_dim"] = list(
|
||||||
|
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||||
|
)
|
||||||
# else add it to list
|
# else add it to list
|
||||||
else:
|
else:
|
||||||
all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim}
|
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||||
|
|
||||||
next_node_list.append(arg_node)
|
next_node_list.append(arg_node)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer):
|
def flow_search(
|
||||||
|
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
|
||||||
|
):
|
||||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
)
|
)
|
||||||
# only single ouput
|
# only single ouput
|
||||||
if len(outputs) > 1:
|
if len(outputs) > 1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
|
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
|
||||||
all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}}
|
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||||
|
|
||||||
while len(cur_node_list) > 0:
|
while len(cur_node_list) > 0:
|
||||||
next_node_list = []
|
next_node_list = []
|
||||||
|
|
||||||
for cur_node in cur_node_list:
|
for cur_node in cur_node_list:
|
||||||
# get cur node info
|
# get cur node info
|
||||||
cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim']
|
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||||
cur_node_fix_dim = all_node_info[cur_node]['fix_dim']
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
|
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
|
||||||
if cur_node_chunk_dim:
|
if cur_node_chunk_dim:
|
||||||
cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node)
|
cur_node_compute = index_tracer._find_compute_trace_from_node(
|
||||||
cur_node_source = index_tracer._find_source_trace_from_node(cur_node)
|
cur_node
|
||||||
|
)
|
||||||
|
cur_node_source = index_tracer._find_source_trace_from_node(
|
||||||
|
cur_node
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cur_node_compute = cur_node_source = None
|
cur_node_compute = cur_node_source = None
|
||||||
|
|
||||||
# get all valid args
|
# get all valid args
|
||||||
arg_list = []
|
arg_list = []
|
||||||
for arg in cur_node.args:
|
for arg in cur_node.args:
|
||||||
@ -1032,20 +1055,33 @@ class FlowTracer(object):
|
|||||||
if _is_non_compute_node(arg):
|
if _is_non_compute_node(arg):
|
||||||
continue
|
continue
|
||||||
arg_list.append(arg)
|
arg_list.append(arg)
|
||||||
flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx,
|
flow_flag = self._assgin_single_node_flow(
|
||||||
inputs, index_tracer, cur_node_chunk_dim,
|
arg,
|
||||||
cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info,
|
start_idx,
|
||||||
next_node_list)
|
end_idx,
|
||||||
|
inputs,
|
||||||
|
index_tracer,
|
||||||
|
cur_node_chunk_dim,
|
||||||
|
cur_node_compute,
|
||||||
|
cur_node_source,
|
||||||
|
cur_node_fix_dim,
|
||||||
|
all_node_info,
|
||||||
|
next_node_list,
|
||||||
|
)
|
||||||
if flow_flag == False:
|
if flow_flag == False:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if len(arg_list) == 2:
|
if len(arg_list) == 2:
|
||||||
if any(i in cur_node.name for i in ["add", "mul"]):
|
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||||
for arg in arg_list:
|
for arg in arg_list:
|
||||||
if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx):
|
if not (
|
||||||
|
start_idx
|
||||||
|
<= _find_idx_by_name(arg.name, index_tracer.nodes_list)
|
||||||
|
< end_idx
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
arg_chunk_dim = all_node_info[arg]['chunk_dim']
|
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||||
arg_fix_dim = all_node_info[arg]['fix_dim']
|
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||||
arg_shape = _get_node_shape(arg)
|
arg_shape = _get_node_shape(arg)
|
||||||
# add all dim as fix dim except chunk dim
|
# add all dim as fix dim except chunk dim
|
||||||
for i, shape in enumerate(arg_shape):
|
for i, shape in enumerate(arg_shape):
|
||||||
@ -1061,7 +1097,7 @@ class FlowTracer(object):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
cur_node_list = next_node_list
|
cur_node_list = next_node_list
|
||||||
|
|
||||||
inputs_dim = []
|
inputs_dim = []
|
||||||
remove_inputs = []
|
remove_inputs = []
|
||||||
for input_node in inputs:
|
for input_node in inputs:
|
||||||
@ -1071,7 +1107,7 @@ class FlowTracer(object):
|
|||||||
continue
|
continue
|
||||||
user_idx = _find_idx_by_name(user.name, self.node_list)
|
user_idx = _find_idx_by_name(user.name, self.node_list)
|
||||||
if start_idx <= user_idx <= end_idx:
|
if start_idx <= user_idx <= end_idx:
|
||||||
chunk_dim = all_node_info[user]['chunk_dim']
|
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||||
if chunk_dim is not None:
|
if chunk_dim is not None:
|
||||||
input_dict[user_idx] = chunk_dim
|
input_dict[user_idx] = chunk_dim
|
||||||
if len(input_dict) == 0:
|
if len(input_dict) == 0:
|
||||||
@ -1081,7 +1117,7 @@ class FlowTracer(object):
|
|||||||
for i in remove_inputs:
|
for i in remove_inputs:
|
||||||
if i in inputs:
|
if i in inputs:
|
||||||
inputs.remove(i)
|
inputs.remove(i)
|
||||||
|
|
||||||
chunk_info = {
|
chunk_info = {
|
||||||
"region": (start_idx, end_idx),
|
"region": (start_idx, end_idx),
|
||||||
"inputs": inputs,
|
"inputs": inputs,
|
||||||
@ -1091,7 +1127,7 @@ class FlowTracer(object):
|
|||||||
"outputs_dim": end_dim,
|
"outputs_dim": end_dim,
|
||||||
"args": {},
|
"args": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# we need to log input nodes to avoid deleteing them in the loop
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
non_chunk_inputs = _find_chunk_all_input_nodes(
|
non_chunk_inputs = _find_chunk_all_input_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
|
|||||||
|
|
||||||
def _add_active_node(self, n, active_list):
|
def _add_active_node(self, n, active_list):
|
||||||
new_active = self._get_output_node(n)[1]
|
new_active = self._get_output_node(n)[1]
|
||||||
if n.op == 'placeholder':
|
if n.op == "placeholder":
|
||||||
new_active.append(n.name)
|
new_active.append(n.name)
|
||||||
for i in new_active:
|
for i in new_active:
|
||||||
if i not in active_list:
|
if i not in active_list:
|
||||||
@ -1168,12 +1204,16 @@ class MemoryEstimator(object):
|
|||||||
for i in delete_node:
|
for i in delete_node:
|
||||||
if i in active_list:
|
if i in active_list:
|
||||||
active_list.remove(i)
|
active_list.remove(i)
|
||||||
|
|
||||||
def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
|
def _get_chunk_inputs_size(
|
||||||
|
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
|
||||||
|
):
|
||||||
nodes_to_delete = []
|
nodes_to_delete = []
|
||||||
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
|
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
|
||||||
chunk_input_users = chunk_input.users.keys()
|
chunk_input_users = chunk_input.users.keys()
|
||||||
chunk_input_users_idx = [_find_idx_by_name(i.name, node_list) for i in chunk_input_users]
|
chunk_input_users_idx = [
|
||||||
|
_find_idx_by_name(i.name, node_list) for i in chunk_input_users
|
||||||
|
]
|
||||||
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
|
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
|
||||||
if chunk_input not in nodes_to_delete:
|
if chunk_input not in nodes_to_delete:
|
||||||
nodes_to_delete.append(chunk_input)
|
nodes_to_delete.append(chunk_input)
|
||||||
@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
|
|||||||
for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
|
for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim):
|
||||||
for k, v in input_node_dim.items():
|
for k, v in input_node_dim.items():
|
||||||
# TODO: inherit dim should be list too, int now
|
# TODO: inherit dim should be list too, int now
|
||||||
inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k])
|
inherit_dim = self.index_tracer._find_inherit_dim(
|
||||||
|
input_node, v, self.index_tracer.nodes_list[k]
|
||||||
|
)
|
||||||
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
|
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
|
||||||
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
||||||
return chunk_ratio
|
return chunk_ratio
|
||||||
@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
|
|||||||
if k in source and inherit_dim in source[k]:
|
if k in source and inherit_dim in source[k]:
|
||||||
chunk_ratio = float(chunk_size) / node_shape[dim]
|
chunk_ratio = float(chunk_size) / node_shape[dim]
|
||||||
return chunk_ratio
|
return chunk_ratio
|
||||||
return 1.
|
return 1.0
|
||||||
|
|
||||||
def _get_chunk_delete_node_size(
|
def _get_chunk_delete_node_size(
|
||||||
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
|
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
|
||||||
@ -1295,7 +1337,7 @@ class MemoryEstimator(object):
|
|||||||
chunk_ratio = 1 # use it to estimate chunk mem
|
chunk_ratio = 1 # use it to estimate chunk mem
|
||||||
chunk_size = 1
|
chunk_size = 1
|
||||||
chunk_inputs_names = []
|
chunk_inputs_names = []
|
||||||
|
|
||||||
if use_chunk:
|
if use_chunk:
|
||||||
chunk_regions = [i["region"] for i in chunk_infos]
|
chunk_regions = [i["region"] for i in chunk_infos]
|
||||||
chunk_starts = [i[0] for i in chunk_regions]
|
chunk_starts = [i[0] for i in chunk_regions]
|
||||||
@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
|
|||||||
if use_chunk and idx in chunk_starts:
|
if use_chunk and idx in chunk_starts:
|
||||||
chunk_within = True
|
chunk_within = True
|
||||||
chunk_region_idx = chunk_starts.index(idx)
|
chunk_region_idx = chunk_starts.index(idx)
|
||||||
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
|
act_memory += self._get_output_node_size(
|
||||||
|
chunk_outputs[chunk_region_idx]
|
||||||
|
) / (1024**2)
|
||||||
|
|
||||||
# determine chunk ratio for current node
|
# determine chunk ratio for current node
|
||||||
if chunk_within:
|
if chunk_within:
|
||||||
chunk_ratio = self._get_chunk_ratio(
|
chunk_ratio = self._get_chunk_ratio(
|
||||||
node, chunk_inputs[chunk_region_idx], chunk_inputs_dim[chunk_region_idx], chunk_size
|
node,
|
||||||
|
chunk_inputs[chunk_region_idx],
|
||||||
|
chunk_inputs_dim[chunk_region_idx],
|
||||||
|
chunk_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if node is placeholder, just add the size of the node
|
# if node is placeholder, just add the size of the node
|
||||||
@ -1353,18 +1400,18 @@ class MemoryEstimator(object):
|
|||||||
/ (1024**2)
|
/ (1024**2)
|
||||||
)
|
)
|
||||||
# delete unused vars not in chunk_input_list
|
# delete unused vars not in chunk_input_list
|
||||||
# we can't delete input nodes until chunk ends
|
# we can't delete input nodes until chunk ends
|
||||||
if chunk_within:
|
if chunk_within:
|
||||||
act_memory -= self._get_chunk_delete_node_size(
|
act_memory -= self._get_chunk_delete_node_size(
|
||||||
node,
|
node,
|
||||||
user_to_last_uses_no_free_var,
|
user_to_last_uses_no_free_var,
|
||||||
chunk_ratio,
|
chunk_ratio,
|
||||||
chunk_inputs_names
|
chunk_inputs_names,
|
||||||
) / (1024**2)
|
) / (1024**2)
|
||||||
else:
|
else:
|
||||||
act_memory -= (self._get_delete_node_size(
|
act_memory -= self._get_delete_node_size(
|
||||||
node, user_to_last_uses_no_free_var, chunk_inputs_names
|
node, user_to_last_uses_no_free_var, chunk_inputs_names
|
||||||
) / (1024**2))
|
) / (1024**2)
|
||||||
|
|
||||||
# log active node, only effective without chunk
|
# log active node, only effective without chunk
|
||||||
self._add_active_node(node, active_node_list)
|
self._add_active_node(node, active_node_list)
|
||||||
@ -1376,11 +1423,11 @@ class MemoryEstimator(object):
|
|||||||
self._get_output_node_size(node) * chunk_ratio / (1024**2)
|
self._get_output_node_size(node) * chunk_ratio / (1024**2)
|
||||||
)
|
)
|
||||||
act_memory -= self._get_chunk_inputs_size(
|
act_memory -= self._get_chunk_inputs_size(
|
||||||
chunk_inputs[chunk_region_idx],
|
chunk_inputs[chunk_region_idx],
|
||||||
chunk_inputs_non_chunk[chunk_region_idx],
|
chunk_inputs_non_chunk[chunk_region_idx],
|
||||||
node_list,
|
node_list,
|
||||||
chunk_regions[chunk_region_idx][1]
|
chunk_regions[chunk_region_idx][1],
|
||||||
) / (1024**2)
|
) / (1024**2)
|
||||||
chunk_within = False
|
chunk_within = False
|
||||||
chunk_ratio = 1
|
chunk_ratio = 1
|
||||||
chunk_region_idx = None
|
chunk_region_idx = None
|
||||||
@ -1436,7 +1483,7 @@ class ChunkRegionSearch(object):
|
|||||||
active_node_num = [len(i) for i in active_node]
|
active_node_num = [len(i) for i in active_node]
|
||||||
min_active_node_num = min(active_node_num[free_var_num:])
|
min_active_node_num = min(active_node_num[free_var_num:])
|
||||||
threshold = max(free_var_num, min_active_node_num)
|
threshold = max(free_var_num, min_active_node_num)
|
||||||
|
|
||||||
# from peak_node to free_var
|
# from peak_node to free_var
|
||||||
inside_flag = False
|
inside_flag = False
|
||||||
chunk_region_start = free_var_num
|
chunk_region_start = free_var_num
|
||||||
@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
|
|||||||
continue
|
continue
|
||||||
for start_node, start_trace in start_traces.items():
|
for start_node, start_trace in start_traces.items():
|
||||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
||||||
if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2:
|
if (
|
||||||
|
start_idx == 199
|
||||||
|
and end_idx == 229
|
||||||
|
and start_dim == 2
|
||||||
|
and end_dim == 2
|
||||||
|
):
|
||||||
print(1)
|
print(1)
|
||||||
self.flow_tracer.flow_search(
|
self.flow_tracer.flow_search(
|
||||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
||||||
@ -1576,7 +1628,7 @@ class ChunkRegionSearch(object):
|
|||||||
max_region_range = 0
|
max_region_range = 0
|
||||||
best_region = None
|
best_region = None
|
||||||
return best_region
|
return best_region
|
||||||
|
|
||||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||||
if cur_chunk_info in chunk_infos:
|
if cur_chunk_info in chunk_infos:
|
||||||
@ -1585,11 +1637,13 @@ class ChunkRegionSearch(object):
|
|||||||
return False
|
return False
|
||||||
for i in chunk_infos:
|
for i in chunk_infos:
|
||||||
region = i["region"]
|
region = i["region"]
|
||||||
if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
|
if not (
|
||||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])):
|
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||||
|
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||||
peak_node = self._find_peak_node(mem_peak)
|
peak_node = self._find_peak_node(mem_peak)
|
||||||
max_chunk_region = self._search_max_chunk_region(
|
max_chunk_region = self._search_max_chunk_region(
|
||||||
@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
|
|||||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||||
max_chunk_region, peak_node
|
max_chunk_region, peak_node
|
||||||
)
|
)
|
||||||
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
|
best_chunk_region = self._search_best_chunk_region(
|
||||||
|
possible_chunk_regions, chunk_regions
|
||||||
|
)
|
||||||
return best_chunk_region
|
return best_chunk_region
|
||||||
|
|
||||||
def _stop_search(self, init_mem_peak, mem_peak):
|
def _stop_search(self, init_mem_peak, mem_peak):
|
||||||
@ -1667,7 +1723,11 @@ def _gen_loop_end(
|
|||||||
chunk_slice = _gen_chunk_slice_dim(
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
chunk_outputs_dim, "chunk_idx", chunk_output_shape
|
chunk_outputs_dim, "chunk_idx", chunk_output_shape
|
||||||
)
|
)
|
||||||
context = " chunk_result%s = %s; %s = None\n" % (chunk_slice, chunk_outputs_name, chunk_outputs_name)
|
context = " chunk_result%s = %s; %s = None\n" % (
|
||||||
|
chunk_slice,
|
||||||
|
chunk_outputs_name,
|
||||||
|
chunk_outputs_name,
|
||||||
|
)
|
||||||
context += (
|
context += (
|
||||||
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
|
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user