diff --git a/chunk_codegen.py b/chunk_codegen.py index 7c334c617..6f8ff2b23 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1406,9 +1406,9 @@ class MemoryEstimator(object): # self._print_mem_log(act_memory_peak_log, node_list, "peak") # self._print_mem_log(act_memory_after_node_log, node_list, "after") self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - self._print_compute_op_mem_log( - act_memory_after_node_log, node_list, "after" - ) + # self._print_compute_op_mem_log( + # act_memory_after_node_log, node_list, "after" + # ) # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory @@ -1465,6 +1465,9 @@ class ChunkSelector(object): if i in possible_chunk_regions: possible_chunk_regions.remove(i) + if len(possible_chunk_regions) == 0: + return None + # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: @@ -1492,7 +1495,7 @@ class ChunkSelector(object): ) # no region found if len(regions_dict) == 0: - return None + raise RuntimeError("Search failed. Try a larger memory threshold.") # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] @@ -1995,6 +1998,14 @@ def emit_code_with_chunk( body[-1] = _replace_name( body[-1], input_node.name, input_node.name + chunk_slice ) + # ones like + if "ones_like" in node.name: + chunk_slice = _gen_chunk_slice_dim( + chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) body[-1] = _replace_reshape_size( body[-1], node.name, chunk_search[region_idx]["reshape_size"] )