diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 0323e3a7e..221217e2d 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -812,19 +812,7 @@ class IndexTracer(object): cur_node_list = next_node_list return all_node_info - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - - # get every node's chunk dim and fix dim - all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) - if all_node_info is None: - return None - + def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info): inputs_dim = [] remove_inputs = [] for input_node in inputs: @@ -841,7 +829,7 @@ class IndexTracer(object): if input_node_idx in user_source: input_dict[user_idx] = user_source[input_node_idx] else: - return None + return None, None if len(input_dict) == 0: remove_inputs.append(input_node) else: @@ -849,6 +837,25 @@ class IndexTracer(object): for i in remove_inputs: if i in inputs: inputs.remove(i) + return inputs, inputs_dim + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single output + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if inputs is None: + return None chunk_info = { "region": (start_idx, end_idx),