diff --git a/chunk_codegen.py b/chunk_codegen.py index 40196285e..e2786d5e2 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -857,112 +857,6 @@ class FlowTracer(object): ) return self.flow_trace - def _detect_flow( - self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer - ): - inputs, outputs = _find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - chunk_info = { - "region": (start_idx, end_idx), - "inputs": inputs, - "inputs_non_chunk": [], - "inputs_dim": start_dim, - "outputs": outputs, - "outputs_dim": end_dim, - "args": {}, - } - flow_block = False - - # TODO don't allow multi outputs now - if len(outputs) > 1: - flow_block = True - return flow_block, chunk_info - - # for idx in range(start_idx, end_idx + 1): - # node = self.node_list[idx] - # mix_flow_node = self._get_flow_mix_node(node) - # if mix_flow_node is None: - # continue - - # # if there is a flow mix, op must be in [mul, add, matmul] - # # element-wise op requires dim to be equal in every dim - # if any(n in node.name for n in ["mul", "add"]): - # for i in node.args: - # if type(i) == type(mix_flow_node) and i != mix_flow_node: - # main_flow_var = i - # # if mix flow is a broadcast in chunk dim, - # # TODO: need to move that flow out of the chunk - # mix_flow_node_dim = index_tracer.get_node_chunk_dim( - # self.node_list[end_idx], end_dim, node - # ) - # # TODO: we need to loop every dim - # if isinstance(mix_flow_node_dim, list): - # mix_flow_node_dim = mix_flow_node_dim[0] - # if mix_flow_node_dim is None: - # flow_block = True - # break - # if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - # flow_block = False - # for i in self._get_same_flow_node( - # chunk_info["inputs"], mix_flow_node - # ): - # chunk_info["inputs"].remove(i) - # # else, we need to chunk mix var as well - # else: - # # TODO chunk another value - # flow_block = True - # break - # else: - # raise NotImplementedError("%s not implemented" % node.name) - # if flow_block: - # flow_block = True - # return flow_block, chunk_info - - inputs_dim = [] - remove_inputs = [] - for input_node in chunk_info["inputs"]: - input_dict = {} - for user in input_node.users.keys(): - if _is_non_compute_node(user): - continue - user_idx = _find_idx_by_name(user.name, self.node_list) - dim = None - if start_dim <= user_idx < end_idx: - dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, input_node - ) - # TODO: we need to loop every dim - if isinstance(dim, list): - dim = dim[0] - elif user_idx == end_idx: - dim = end_dim - # n has relation with chunk dim - if dim is not None and _get_node_shape(user)[dim] != 1: - input_dict[user_idx] = dim - if len(input_dict) == 0: - remove_inputs.append(input_node) - else: - inputs_dim.append(input_dict) - chunk_info["inputs_dim"] = inputs_dim - for i in remove_inputs: - if i in chunk_info["inputs"]: - chunk_info["inputs"].remove(i) - - duplicate_result, duplicate_dim = index_tracer.check_index_duplicate( - chunk_info, return_dim=True - ) - - # we need to log input nodes to avoid deleteing them in the loop - non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] - ) - for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: - chunk_info["inputs_non_chunk"].append(i) - - return flow_block, chunk_info - def _assgin_single_node_flow( self, arg_node,