diff --git a/chunk_codegen.py b/chunk_codegen.py index 330f3dec6..1255852d7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -982,24 +982,24 @@ class IndexTracer(object): # reassgin reshape size, some size may have changed due to chunk chunk_info = self._reassgin_reshape_size(chunk_info) - + return chunk_info - + def _reassgin_reshape_size(self, chunk_info): - chunk_region = chunk_info['region'] + chunk_region = chunk_info["region"] reshape_size = {} - for node in self.node_list[chunk_region[0]: chunk_region[1] + 1]: - if any(i in node.name for i in ['reshape', 'view']): + for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: + if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] reshape_log = self.idx_view_list[node] - chunk_dim = chunk_info['node_chunk_dim'][node]['chunk_dim'] + chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] reshape_size[node.name] = {} for reshape_arg_dim, reshape_arg in enumerate(reshape_args): - if reshape_arg_dim in reshape_log['dim_to']: + if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: reshape_size[node.name][reshape_arg.name] = "chunk_size" - chunk_info['reshape_size'] = reshape_size + chunk_info["reshape_size"] = reshape_size return chunk_info def _get_reorder_map(self, chunk_info): @@ -1213,7 +1213,7 @@ class MemoryEstimator(object): if node not in chunk_node_dim: return 1.0 node_shape = _get_node_shape(node) - chunk_dim = chunk_node_dim[node]['chunk_dim'] + chunk_dim = chunk_node_dim[node]["chunk_dim"] if chunk_dim is None: return 1.0 else: @@ -1381,7 +1381,9 @@ class MemoryEstimator(object): # self._print_mem_log(act_memory_peak_log, node_list, "peak") # self._print_mem_log(act_memory_after_node_log, node_list, "after") self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - self._print_compute_op_mem_log(act_memory_after_node_log, node_list, "after") + self._print_compute_op_mem_log( + act_memory_after_node_log, node_list, "after" + ) # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -1389,30 +1391,41 @@ class MemoryEstimator(object): class ChunkSelector(object): - def __init__(self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge): + def __init__( + self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge + ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator - assert stratge in ['min_memory', 'fit_memory'] + assert stratge in ["min_memory", "fit_memory"] self.stratge = stratge self.max_memory = 600 # MB - - def _select_best_chunk_region(self, possible_chunk_regions, - chunk_infos, peak_node, max_chunk_region, mem_peak): - if self.stratge == 'min_memory': - best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) - elif self.stratge == 'fit_memory': + + def _select_best_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + if self.stratge == "min_memory": + best_region = self._select_min_memory_chunk_region( + possible_chunk_regions, chunk_infos + ) + elif self.stratge == "fit_memory": best_region = self._select_fit_memory_chunk_region( - possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak) + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, + ) else: raise RuntimeError() return best_region - - def _select_fit_memory_chunk_region(self, possible_chunk_regions, - chunk_infos, peak_node, max_chunk_region, mem_peak): + + def _select_fit_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): # stop chunk if max memory satisfy memory limit if max(mem_peak) < self.max_memory: return None - + # remove illegal regions illegal_regions = [] for i in possible_chunk_regions: @@ -1421,38 +1434,45 @@ class ChunkSelector(object): for i in illegal_regions: if i in possible_chunk_regions: possible_chunk_regions.remove(i) - + # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: cur_chunk_infos = chunk_infos + [region] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem_peak[max_chunk_region[0]: max_chunk_region[1] + 1] + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region['region'][0], region['region'][1]), - }) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + } + ) # no region found if len(regions_dict) == 0: return None - + # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) best_region = regions_dict[best_region_idx]["chunk_info"] return best_region - + def _get_compute_node_num(self, start, end): count = 0 - for i in self.index_tracer.node_list[start: end+1]: + for i in self.index_tracer.node_list[start : end + 1]: if _is_non_compute_node(i): count += 1 return count - + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 best_region = None @@ -1490,7 +1510,9 @@ class ChunkRegionSearch(object): self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) - self.chunk_selector = ChunkSelector(self.index_tracer, self.memory_estimator, stratge="fit_memory") + self.chunk_selector = ChunkSelector( + self.index_tracer, self.memory_estimator, stratge="fit_memory" + ) def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1808,10 +1830,11 @@ def _replace_name(context, name_from, name_to): def _replace_reshape_size(context, node_name, reshape_size_dict): if node_name not in reshape_size_dict: return context - for size_name, size_value in reshape_size_dict[node_name].items(): + for size_name, size_value in reshape_size_dict[node_name].items(): context = context.replace(size_name, size_value) return context + def emit_code_with_chunk( body, ckpt_func, @@ -1883,7 +1906,9 @@ def emit_code_with_chunk( body[-1] = _replace_name( body[-1], input_node.name, input_node.name + chunk_slice ) - body[-1] = _replace_reshape_size(body[-1], node.name, chunk_search[region_idx]['reshape_size']) + body[-1] = _replace_reshape_size( + body[-1], node.name, chunk_search[region_idx]["reshape_size"] + ) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) else: