code style

This commit is contained in:
oahzxl 2022-12-27 09:48:59 +08:00
parent 8f5a0edfab
commit 378a49dc6c

View File

@ -986,20 +986,20 @@ class IndexTracer(object):
return chunk_info return chunk_info
def _reassgin_reshape_size(self, chunk_info): def _reassgin_reshape_size(self, chunk_info):
chunk_region = chunk_info['region'] chunk_region = chunk_info["region"]
reshape_size = {} reshape_size = {}
for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]: for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
if any(i in node.name for i in ['reshape', 'view']): if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:] reshape_args = node.args[1:]
reshape_log = self.idx_view_list[node] reshape_log = self.idx_view_list[node]
chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim'] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {} reshape_size[node.name] = {}
for reshape_arg_dim, reshape_arg in enumerate(reshape_args): for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim in reshape_log['dim_to']: if reshape_arg_dim in reshape_log["dim_to"]:
continue continue
if reshape_arg_dim == chunk_dim: if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = "chunk_size" reshape_size[node.name][reshape_arg.name] = "chunk_size"
chunk_info['reshape_size'] = reshape_size chunk_info["reshape_size"] = reshape_size
return chunk_info return chunk_info
def _get_reorder_map(self, chunk_info): def _get_reorder_map(self, chunk_info):
@ -1213,7 +1213,7 @@ class MemoryEstimator(object):
if node not in chunk_node_dim: if node not in chunk_node_dim:
return 1.0 return 1.0
node_shape = _get_node_shape(node) node_shape = _get_node_shape(node)
chunk_dim = chunk_node_dim[node]['chunk_dim'] chunk_dim = chunk_node_dim[node]["chunk_dim"]
if chunk_dim is None: if chunk_dim is None:
return 1.0 return 1.0
else: else:
@ -1381,7 +1381,9 @@ class MemoryEstimator(object):
# self._print_mem_log(act_memory_peak_log, node_list, "peak") # self._print_mem_log(act_memory_peak_log, node_list, "peak")
# self._print_mem_log(act_memory_after_node_log, node_list, "after") # self._print_mem_log(act_memory_after_node_log, node_list, "after")
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") self._print_compute_op_mem_log(
act_memory_after_node_log, node_list, "after"
)
# param_memory = parameter_size(gm) # param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory # all_memory = act_memory + param_memory
@ -1389,26 +1391,37 @@ class MemoryEstimator(object):
class ChunkSelector(object): class ChunkSelector(object):
def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge): def __init__(
self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge
):
self.index_tracer = index_tracer self.index_tracer = index_tracer
self.memory_estimator = memory_estimator self.memory_estimator = memory_estimator
assert stratge in ['min_memory', 'fit_memory'] assert stratge in ["min_memory", "fit_memory"]
self.stratge = stratge self.stratge = stratge
self.max_memory = 600 # MB self.max_memory = 600 # MB
def _select_best_chunk_region(self, possible_chunk_regions, def _select_best_chunk_region(
chunk_infos, peak_node, max_chunk_region, mem_peak): self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
if self.stratge == 'min_memory': ):
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) if self.stratge == "min_memory":
elif self.stratge == 'fit_memory': best_region = self._select_min_memory_chunk_region(
possible_chunk_regions, chunk_infos
)
elif self.stratge == "fit_memory":
best_region = self._select_fit_memory_chunk_region( best_region = self._select_fit_memory_chunk_region(
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak) possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
else: else:
raise RuntimeError() raise RuntimeError()
return best_region return best_region
def _select_fit_memory_chunk_region(self, possible_chunk_regions, def _select_fit_memory_chunk_region(
chunk_infos, peak_node, max_chunk_region, mem_peak): self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
# stop chunk if max memory satisfy memory limit # stop chunk if max memory satisfy memory limit
if max(mem_peak) < self.max_memory: if max(mem_peak) < self.max_memory:
return None return None
@ -1427,15 +1440,22 @@ class ChunkSelector(object):
for region in possible_chunk_regions: for region in possible_chunk_regions:
cur_chunk_infos = chunk_infos + [region] cur_chunk_infos = chunk_infos + [region]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, cur_chunk_infos)[0] self.index_tracer.node_list, cur_chunk_infos
cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1] )[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_chunk_region_max_peak = max(cur_chunk_region_peak) cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory: if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({ regions_dict.append(
{
"chunk_info": region, "chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak, "chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]), "chunk_len": self._get_compute_node_num(
}) region["region"][0], region["region"][1]
),
}
)
# no region found # no region found
if len(regions_dict) == 0: if len(regions_dict) == 0:
return None return None
@ -1448,7 +1468,7 @@ class ChunkSelector(object):
def _get_compute_node_num(self, start, end): def _get_compute_node_num(self, start, end):
count = 0 count = 0
for i in self.index_tracer.node_list[start: end+1]: for i in self.index_tracer.node_list[start : end + 1]:
if _is_non_compute_node(i): if _is_non_compute_node(i):
count += 1 count += 1
return count return count
@ -1490,7 +1510,9 @@ class ChunkRegionSearch(object):
self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index() self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer) self.memory_estimator = MemoryEstimator(self.index_tracer)
self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory") self.chunk_selector = ChunkSelector(
self.index_tracer, self.memory_estimator, stratge="fit_memory"
)
def _find_peak_node(self, mem_peak): def _find_peak_node(self, mem_peak):
max_value = max(mem_peak) max_value = max(mem_peak)
@ -1812,6 +1834,7 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
context = context.replace(size_name, size_value) context = context.replace(size_name, size_value)
return context return context
def emit_code_with_chunk( def emit_code_with_chunk(
body, body,
ckpt_func, ckpt_func,
@ -1883,7 +1906,9 @@ def emit_code_with_chunk(
body[-1] = _replace_name( body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice body[-1], input_node.name, input_node.name + chunk_slice
) )
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size']) body[-1] = _replace_reshape_size(
body[-1], node.name, chunk_search[region_idx]["reshape_size"]
)
body[-1] = " " + body[-1] body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names) delete_unused_value_func(node, body, chunk_inputs_names)
else: else: