diff --git a/chunk_codegen.py b/chunk_codegen.py index 470768855..3cd10350e 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -988,7 +988,9 @@ class IndexTracer(object): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} - chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] + chunk_shape = _get_node_shape(chunk_info["outputs"][0])[ + chunk_info["outputs_dim"] + ] for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] @@ -999,7 +1001,9 @@ class IndexTracer(object): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape + reshape_size[node.name][reshape_arg.name] = ( + "min(chunk_size, %d - chunk_idx)" % chunk_shape + ) chunk_info["reshape_size"] = reshape_size return chunk_info @@ -1498,7 +1502,7 @@ class ChunkSelector(object): else: gap = 1 while r >= l + gap: - mid = int(l + (r - l)/2) + mid = int(l + (r - l) / 2) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( @@ -1938,7 +1942,7 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], - chunk_size=chunk_search[region_idx]["chunk_size"] + chunk_search[region_idx]["chunk_size"], ) )