From ae27a8b26d7a36a3d9215fc6fd1db92982bdeef7 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 14:57:33 +0800 Subject: [PATCH] seperate flow tracer --- colossalai/autochunk/index_tracer.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/colossalai/autochunk/index_tracer.py b/colossalai/autochunk/index_tracer.py index 7a86f3c99..0323e3a7e 100644 --- a/colossalai/autochunk/index_tracer.py +++ b/colossalai/autochunk/index_tracer.py @@ -745,14 +745,7 @@ class IndexTracer(object): next_node_list.append(arg_node) return True - def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_list[start_idx : end_idx + 1] - ) - # only single ouput - if len(outputs) > 1: - return None - + def _get_all_node_info(self, end_dim, start_idx, end_idx): cur_node_list = [self.node_list[end_idx]] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} @@ -763,7 +756,6 @@ class IndexTracer(object): # get cur node info cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"] cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] - cur_node_idx = find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: cur_node_compute = self._find_compute_trace_from_node(cur_node) cur_node_source = self._find_source_trace_from_node(cur_node) @@ -818,6 +810,20 @@ class IndexTracer(object): else: raise NotImplementedError() cur_node_list = next_node_list + return all_node_info + + def flow_search(self, start_idx, start_dim, end_idx, end_dim): + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + # get every node's chunk dim and fix dim + all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) + if all_node_info is None: + return None inputs_dim = [] remove_inputs = []