diff --git a/chunk_codegen.py b/chunk_codegen.py index a8b970116..e3a7643d7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1198,7 +1198,7 @@ class FlowTracer(object): chunk_node_list.remove(n) non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: - if i not in chunk_info["inputs"] and i not in prepose_nodes: + if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) return chunk_info @@ -1425,6 +1425,7 @@ class MemoryEstimator(object): ) / (1024**2) # determine chunk ratio for current node + # TODO: adapt to prepose node memory if chunk_within: chunk_ratio = self._get_chunk_ratio( node, @@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object): chunk_infos = [] for end_dim, end_trace_idx in enumerate(end_trace["idx"]): if len(start_traces) > 1: - # TODO: implement multi input chunk continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): @@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]): # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output - # TODO: it is unsafe to remove non compute node here for node in nodes: for output_node in node.users.keys(): if ( @@ -1900,6 +1899,8 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] + + chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search] node_idx = 0 region_idx = 0 @@ -1911,7 +1912,11 @@ def emit_code_with_chunk( if node_idx in chunk_starts: within_chunk_region = True region_idx = chunk_starts.index(node_idx) - + # add prepose nodes + for i in chunk_prepose_nodes[region_idx]: + prepose_node = node_list[_find_idx_by_name(i.name, node_list)] + emit_node_func(prepose_node, body) + delete_unused_value_func(prepose_node, body, chunk_inputs_names) # add for loop body.append( _gen_loop_start( @@ -1922,20 +1927,22 @@ def emit_code_with_chunk( ) if within_chunk_region: - emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim, "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) - + if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]): + pass + else: + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) else: emit_node_func(node, body) if node_idx not in chunk_inputs: