diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index ddf64dc8f..82937db9f 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE: from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from .search_chunk import SearchChunk -from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape +from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str: @@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> return new_shape -def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str: +def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str: """ Generate chunk loop start @@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim context (str): generated str """ input_node = chunk_input[0] - out_shape = get_node_shape(chunk_output) - out_str = str(list(out_shape)) - context = ( - "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" % - (out_str, input_node.name, input_node.name, chunk_size)) - context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim]) + + context = "" + for i in range(len(chunk_output)): + shape_str = str(list(get_node_shape(chunk_output[i]))) + if get_node_name(chunk_output[i]) == "split": + tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, + input_node.name) + tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) + tensor_str = "[" + tensor_str[:-2] + "]" + context += "%s = %s; " % (chunk_output[i].name, tensor_str) + else: + context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, + input_node.name, input_node.name) + + out_shape = get_node_shape(chunk_output[0]) + chunk_shape = out_shape[chunk_ouput_dim[0]] + context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape) return context -def _gen_loop_end( - chunk_inputs: List[Node], - chunk_non_compute_inputs: List[Node], - chunk_outputs: Node, - chunk_outputs_dim: int, - node_list: List[Node], -) -> str: +def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], + chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: """ Generate chunk loop end @@ -102,22 +108,13 @@ def _gen_loop_end( Returns: context (str): generated str """ - chunk_outputs_name = chunk_outputs.name - chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) - chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape - chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape) - context = " chunk_result%s = %s; %s = None\n" % ( - chunk_slice, - chunk_outputs_name, - chunk_outputs_name, - ) - context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None") - + context = "chunk_size = None" # determine if its the last use for chunk input for chunk_input in chunk_inputs + chunk_non_compute_inputs: - if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]): + if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]): context += "; %s = None" % chunk_input.name - + for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items(): + context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val) context += "\n" return context @@ -158,7 +155,7 @@ def _replace_ones_like( add chunk slice for new tensor op such as ones like """ if "ones_like" in node.name: - meta_node = search_chunk.trace_indice.node_list[node_idx] + meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx) chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] @@ -169,21 +166,37 @@ def _replace_ones_like( return body -def _replace_input_node( - chunk_inputs: List[Node], +def _add_node_slice( + chunk_nodes: List[Node], region_idx: int, - chunk_inputs_dim: Dict, + chunk_nodes_dim: Dict, node_idx: int, body: List[str], + node: Node, ) -> List[str]: """ add chunk slice for input nodes """ - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node)) - body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice) + for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]): + # inputs node + if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict): + for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node)) + body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice) + # outputs node + else: + if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): + chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", + get_node_shape(chunk_node)) + if get_node_name(chunk_node) == "split": + split_chunk_slice = "" + for i in range(len(chunk_node.meta['tensor_meta'])): + split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) + split_chunk_slice = split_chunk_slice[:-2] + body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) + else: + body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice) return body @@ -222,7 +235,8 @@ def emit_code_with_chunk( chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs - chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_outputs = [i["outputs"] for i in chunk_infos] + chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos] node_list = search_chunk.reorder_graph.reorder_node_list(node_list) @@ -248,7 +262,9 @@ def emit_code_with_chunk( if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body) + body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node) + # replace output var with chunk var + body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node) # ones like body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body) # reassgin reshape size @@ -263,13 +279,8 @@ def emit_code_with_chunk( # generate chunk region end if node_idx in chunk_ends: body.append( - _gen_loop_end( - chunk_inputs[region_idx], - chunk_inputs_non_chunk[region_idx], - chunk_outputs[region_idx], - chunk_outputs_dim[region_idx], - node_list, - )) + _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, + chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) within_chunk_region = False node_idx += 1 diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index a03a5413b..f457696e6 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg from colossalai.fx.profiler import activation_size, parameter_size -from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node +from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node class EstimateMemory(object): @@ -14,8 +14,8 @@ class EstimateMemory(object): Estimate memory with chunk """ - def __init__(self) -> None: - pass + def __init__(self, node_mgr: NodeMgr) -> None: + self.node_mgr = node_mgr def _get_meta_node_size(self, x): x = x.meta["tensor_meta"] @@ -78,7 +78,7 @@ class EstimateMemory(object): nodes_to_delete = [] for chunk_input in chunk_inputs + chunk_inputs_non_chunk: chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users] + chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users] if all(i <= chunk_end_idx for i in chunk_input_users_idx): if chunk_input not in nodes_to_delete: nodes_to_delete.append(chunk_input) @@ -212,7 +212,7 @@ class EstimateMemory(object): chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] chunk_inputs_names = [j.name for i in chunk_inputs for j in i ] + [j.name for i in chunk_inputs_non_chunk for j in i] - chunk_outputs = [i["outputs"][0] for i in chunk_infos] + chunk_outputs = [i["outputs"] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] @@ -221,7 +221,7 @@ class EstimateMemory(object): if use_chunk and idx in chunk_starts: chunk_within = True chunk_region_idx = chunk_starts.index(idx) - act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2) + act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2) # determine chunk ratio for current node if chunk_within: diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py index 0343e52ee..3b00d47fb 100644 --- a/colossalai/autochunk/reorder_graph.py +++ b/colossalai/autochunk/reorder_graph.py @@ -1,5 +1,5 @@ from .trace_indice import TraceIndice -from .utils import find_idx_by_name +from .utils import NodeMgr class ReorderGraph(object): @@ -7,31 +7,27 @@ class ReorderGraph(object): Reorder node list and indice trace list """ - def __init__(self, trace_indice: TraceIndice) -> None: + def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: self.trace_indice = trace_indice - self.all_reorder_map = { - i: i for i in range(len(self.trace_indice.indice_trace_list)) - } + self.node_mgr = node_mgr + self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))} def _get_reorder_map(self, chunk_info): - reorder_map = {i: i for i in range(len(self.trace_indice.node_list))} + reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))} chunk_region_start = chunk_info["region"][0] chunk_region_end = chunk_info["region"][1] chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] - chunk_prepose_nodes_idx = [ - find_idx_by_name(i.name, self.trace_indice.node_list) - for i in chunk_prepose_nodes - ] + chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes] # put prepose nodes ahead for idx, n in enumerate(chunk_prepose_nodes): n_idx = chunk_prepose_nodes_idx[idx] reorder_map[n_idx] = chunk_region_start + idx # put other nodes after prepose nodes - for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]: + for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1): if n in chunk_prepose_nodes: continue - n_idx = find_idx_by_name(n.name, self.trace_indice.node_list) + n_idx = self.node_mgr.find_node_idx(n) pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) reorder_map[n_idx] = n_idx + pos @@ -44,7 +40,7 @@ class ReorderGraph(object): chunk_info["region"][1], ) new_inputs_dim = [] - for idx, input_dim in enumerate(chunk_info["inputs_dim"]): + for _, input_dim in enumerate(chunk_info["inputs_dim"]): new_input_dim = {} for k, v in input_dim.items(): new_input_dim[reorder_map[k]] = v @@ -57,16 +53,14 @@ class ReorderGraph(object): self.all_reorder_map[origin_idx] = reorder_map[map_idx] def _reorder_self_node_list(self, reorder_map): - new_node_list = [None for _ in range(len(self.trace_indice.node_list))] + new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))] for old_idx, new_idx in reorder_map.items(): - new_node_list[new_idx] = self.trace_indice.node_list[old_idx] - self.trace_indice.node_list = new_node_list + new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx) + self.node_mgr.update_node_list(new_node_list) def _reorder_idx_trace(self, reorder_map): # reorder list - new_idx_trace_list = [ - None for _ in range(len(self.trace_indice.indice_trace_list)) - ] + new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))] for old_idx, new_idx in reorder_map.items(): new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx] self.trace_indice.indice_trace_list = new_idx_trace_list diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 720f3d925..0278e03f7 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -9,6 +9,7 @@ from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice from .utils import ( + NodeMgr, find_chunk_compute_input_and_output_nodes, get_logger, get_node_shape, @@ -49,15 +50,17 @@ class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None: self.print_mem = print_mem self.print_progress = print_progress - self.trace_indice = TraceIndice(list(gm.graph.nodes)) - self.estimate_memory = EstimateMemory() + self.node_mgr = NodeMgr(gm) + self.trace_indice = TraceIndice(self.node_mgr) + self.estimate_memory = EstimateMemory(self.node_mgr) self._init_trace() - self.trace_flow = TraceFlow(self.trace_indice) - self.reorder_graph = ReorderGraph(self.trace_indice) + self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr) + self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr) self.select_chunk = SelectChunk( self.trace_indice, self.estimate_memory, self.reorder_graph, + self.node_mgr, max_memory=max_memory, ) @@ -67,7 +70,7 @@ class SearchChunk(object): reduce the computation complexity of trace_indice """ # find all max ranges - active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list) + active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list()) cur_node_idx = len(self._get_free_var_idx()) max_chunk_region_list = [] while True: @@ -100,7 +103,7 @@ class SearchChunk(object): free_var_idx (List): all indexs of free vars """ free_var_idx = [] - for idx, n in enumerate(self.trace_indice.node_list): + for idx, n in enumerate(self.node_mgr.get_node_list()): if n.op == "placeholder" and get_node_shape(n) is not None: free_var_idx.append(idx) return free_var_idx @@ -164,6 +167,44 @@ class SearchChunk(object): chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end + def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: + """ + Find chunk info for a region. + + We are given the region start and region end, and need to find out all chunk info for it. + We first loop every dim of start node and end node, to see if we can find dim pair, + which is linked in a flow and not computed. + If found, we then search flow in the whole region to find out all chunk infos. + + Args: + input_trace (List): node's input trace in region + output_trace (List): node's output trace in region + start_idx (int): region start node index + end_idx (int): region end node index + + Returns: + chunk_infos: possible regions found + """ + start_traces = input_trace[start_idx] + if len(start_traces) > 1: # TODO need to be removed + return [] + end_trace = output_trace[end_idx] + end_node = self.node_mgr.get_node_by_idx(end_idx) + + chunk_infos = [] + for end_dim, _ in enumerate(end_trace["indice"]): + for start_node, start_trace in start_traces.items(): + for start_dim, _ in enumerate(start_trace["indice"]): + if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, + end_idx): + continue + # flow search + chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) + if chunk_info is None: + continue + chunk_infos.append(chunk_info) + return chunk_infos + def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List: """ Search every possible region within the max chunk region. @@ -178,7 +219,7 @@ class SearchChunk(object): possible_chunk_region = [] output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) input_trace = [] # trace of a node's input nodes - for _, n in enumerate(self.trace_indice.node_list): + for _, n in enumerate(self.node_mgr.get_node_list()): cur_trace = {} for arg in n.args: if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg): @@ -188,11 +229,11 @@ class SearchChunk(object): for start_idx in range(max_chunk_region[0], peak_node + 1): for end_idx in range(peak_node, max_chunk_region[1] + 1): # skip non compute nodes - if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node( - self.trace_indice.node_list[end_idx]): + if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( + self.node_mgr.get_node_by_idx(end_idx)): continue # select free dim - chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx) + chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) if len(chunk_info) > 0: possible_chunk_region.extend(chunk_info) return possible_chunk_region @@ -254,7 +295,7 @@ class SearchChunk(object): init_mem_peak, _, active_node, - ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list) + ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list()) mem_peak = init_mem_peak while True: @@ -267,7 +308,7 @@ class SearchChunk(object): mem_peak, _, active_node, - ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos) + ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos) if self.print_progress: get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % @@ -277,5 +318,7 @@ class SearchChunk(object): break if self.print_mem: self.print_mem = False - self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True) + self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), + chunk_infos, + print_mem=True) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 1f3a95727..1bb7d318c 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -1,7 +1,7 @@ from .estimate_memory import EstimateMemory from .reorder_graph import ReorderGraph from .trace_indice import TraceIndice -from .utils import is_non_compute_node +from .utils import NodeMgr, is_non_compute_node class SelectChunk(object): @@ -11,11 +11,13 @@ class SelectChunk(object): trace_indice: TraceIndice, estimate_memory: EstimateMemory, reorder_graph: ReorderGraph, + node_mgr: NodeMgr, max_memory=None, ): self.trace_indice = trace_indice self.estimate_memory = estimate_memory self.reorder_graph = reorder_graph + self.node_mgr = node_mgr if max_memory is not None: self.stratge = "fit_memory" self.max_memory = max_memory # MB @@ -68,7 +70,7 @@ class SelectChunk(object): regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] @@ -134,7 +136,7 @@ class SelectChunk(object): def _get_compute_node_num(self, start, end): count = 0 - for i in self.trace_indice.node_list[start:end + 1]: + for i in self.node_mgr.get_node_slice_by_idx(start, end + 1): if not is_non_compute_node(i): count += 1 return count @@ -161,7 +163,7 @@ class SelectChunk(object): regions_dict_list = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region) + cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index df7343764..11dbb266d 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -4,9 +4,10 @@ from torch.fx.node import Node from .trace_indice import TraceIndice from .utils import ( + NodeMgr, find_chunk_all_input_nodes, find_chunk_compute_input_and_output_nodes, - find_idx_by_name, + find_tensor_shape_node, flat_list, get_node_name, get_node_shape, @@ -16,8 +17,9 @@ from .utils import ( class TraceFlow(object): - def __init__(self, trace_indice: TraceIndice) -> None: + def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: self.trace_indice = trace_indice + self.node_mgr = node_mgr def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node): """ @@ -31,7 +33,8 @@ class TraceFlow(object): Returns: bool: True if check pass """ - start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list) + # we use start_node_idx instead of real chunk index + start_node_idx = self.node_mgr.find_node_idx(start_node) end_node_trace = self.trace_indice._find_trace_from_node(end_node) end_node_trace_source = end_node_trace["source"][end_dim] sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True) @@ -39,7 +42,7 @@ class TraceFlow(object): if node_idx == start_node_idx and start_dim in node_dim: return True # it means we meet a node outside the loop, and the node is not input node - if node_idx < start_idx: + if node_idx < start_node_idx: return False return False @@ -61,29 +64,12 @@ class TraceFlow(object): return False return True - def get_node_chunk_dim(self, node_from, node_from_dim, node_to): - node_from_source = self.trace_indice._find_source_trace_from_node(node_from) - dim_source = node_from_source[node_from_dim] - node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list) - for k, v in dim_source.items(): - if k == node_to_idx: - return v - return None - - def _find_inherit_dim(self, input_node, input_dim, node): - input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) - node_trace_source = self.trace_indice._find_source_trace_from_node(node) - for node_dim in range(len(get_node_shape(node))): - if (input_node_idx in node_trace_source[node_dim] - and input_dim[0] in node_trace_source[node_dim][input_node_idx]): - return node_dim - return None - def _assgin_single_node_flow( self, arg_node: Node, start_idx: int, end_idx: int, + cur_node: Node, cur_node_dim: int, cur_node_compute: Dict, cur_node_source: Dict, @@ -109,7 +95,7 @@ class TraceFlow(object): Returns: bool: True if this node can be added to the flow, vice versa. """ - arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list) + arg_idx = self.node_mgr.find_node_idx(arg_node) # arg in chunk range or be inputs if not (start_idx <= arg_idx < end_idx): return True @@ -126,6 +112,11 @@ class TraceFlow(object): # chunk dim should be None if shape size is 1 if get_node_shape(arg_node)[arg_dim] == 1: arg_dim = None + # chunk shape should equal cur node + elif get_node_shape(arg_node)[arg_dim] != 1: + if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1: + if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]: + return False else: arg_dim = None @@ -150,7 +141,7 @@ class TraceFlow(object): return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node + cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -178,6 +169,7 @@ class TraceFlow(object): arg, start_idx, end_idx, + cur_node, cur_node_chunk_dim, cur_node_compute, cur_node_source, @@ -194,7 +186,7 @@ class TraceFlow(object): for arg in arg_list: if get_node_shape(arg) is None: continue - if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx): + if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx): continue arg_chunk_dim = all_node_info[arg]["chunk_dim"] arg_fix_dim = all_node_info[arg]["fix_dim"] @@ -232,7 +224,7 @@ class TraceFlow(object): remove_inputs = [] for input_node in inputs: input_dict = {} - input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list) + input_node_idx = self.node_mgr.find_node_idx(input_node) for user in input_node.users.keys(): # skip non compute if is_non_compute_node(user): @@ -240,7 +232,7 @@ class TraceFlow(object): # untraced node, mostly non compute if user not in all_node_info: continue - user_idx = find_idx_by_name(user.name, self.trace_indice.node_list) + user_idx = self.node_mgr.find_node_idx(user) if start_idx <= user_idx <= end_idx: chunk_dim = all_node_info[user]["chunk_dim"] if chunk_dim is not None: @@ -262,7 +254,7 @@ class TraceFlow(object): inputs.remove(i) return inputs, inputs_dim - def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]: + def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]: """ get all useless nodes in chunk region and prepose them @@ -279,8 +271,11 @@ class TraceFlow(object): for node, node_info in all_node_info.items(): if node_info["chunk_dim"] is None: maybe_prepose_nodes.append(node) + for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx): + if node not in all_node_info and node not in chunk_info["outputs"]: + maybe_prepose_nodes.append(node) maybe_prepose_nodes.sort( - key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list), + key=lambda x: self.node_mgr.find_node_idx(x), reverse=True, ) # from last node to first node prepose_nodes = [] @@ -303,8 +298,7 @@ class TraceFlow(object): if type(cur_prepose_node_arg) != type(cur_prepose_node): continue # out of loop - if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) < - end_idx): + if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx): continue # compute op in loop elif cur_prepose_node_arg in all_node_info: @@ -328,13 +322,12 @@ class TraceFlow(object): if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)) - - return prepose_nodes + prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x)) + chunk_info["args"]["prepose_nodes"] = prepose_nodes def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): # we need to log input nodes to avoid deleteing them in the loop - chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1] + chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) # also need to get some prepose node's arg out of non_chunk_inputs for n in chunk_info["args"]["prepose_nodes"]: chunk_node_list.remove(n) @@ -345,34 +338,41 @@ class TraceFlow(object): return chunk_info def flow_search(self, start_idx, start_dim, end_idx, end_dim): - inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1]) - # only single ouput - if len(outputs) > 1: - return None + inputs, outputs = find_chunk_compute_input_and_output_nodes( + self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) # get every node's chunk dim and fix dim all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) if all_node_info is None: return None - # get input nodes' chunk dim - inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) - if inputs is None: - return None - chunk_info = { "region": (start_idx, end_idx), - "inputs": inputs, + "inputs": [], "inputs_non_chunk": [], - "inputs_dim": inputs_dim, - "outputs": outputs, - "outputs_dim": end_dim, + "inputs_dim": [], + "outputs": [self.node_mgr.get_node_by_idx(end_idx)], + "outputs_non_tensor": {}, + "outputs_dim": [end_dim], "node_chunk_dim": all_node_info, "args": {}, } + # find chunk info for other outputs + if len(find_tensor_shape_node(outputs)) > 1: + chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info) + if chunk_info is None: + return None + + # get input nodes' chunk dim + inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) + if inputs is None: + return None + chunk_info["inputs"] = inputs + chunk_info["inputs_dim"] = inputs_dim + # move useless nodes ahead of loop - chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx) + self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info) # find non chunk inputs chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx) @@ -382,6 +382,63 @@ class TraceFlow(object): return chunk_info + def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, + chunk_info: Dict): + start_node = self.node_mgr.get_node_by_idx(start_idx) + # loop all outputs + for output in outputs: + output_legal = False + output_idx = self.node_mgr.find_node_idx(output) + # skip the origin output + if output_idx == end_idx: + continue + # skip non tensor + if get_node_shape(output) is None: + # log shape tensor + if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): + chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) + continue + # loop every dim of outputs, try to find a legal one + for output_dim in range(len(get_node_shape(output))): + if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx): + continue + new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx) + if new_all_node_info is None: + continue + # check node info legal + if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True: + output_legal = True + break + # not legal + if output_legal == False: + return None + return chunk_info + + def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool: + """ + check if there is conflict between new node info and old chunk info. If not, update old chunk info + """ + # check if conflict + overlap_flag = False + for k, v in new_all_node_info.items(): + if k in chunk_info["node_chunk_dim"]: + overlap_flag = True + if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]: + return False + # if no overlap, we just consider them as prepose nodes, instead of new output + if overlap_flag == False: + return True + # update chunk info + for k, v in new_all_node_info.items(): + if k in chunk_info["node_chunk_dim"]: + chunk_info["node_chunk_dim"][k]["fix_dim"] = list( + set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) + else: + chunk_info["node_chunk_dim"][k] = v + chunk_info["outputs"].append(output) + chunk_info["outputs_dim"].append(output_dim) + return True + def _reassgin_reshape_size(self, chunk_info): """ Some shape args in reshape may have changed due to chunk @@ -389,10 +446,17 @@ class TraceFlow(object): """ chunk_region = chunk_info["region"] reshape_size = {} - chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] - for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]: + chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]] + for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1): if any(i == get_node_name(node) for i in ["reshape", "view"]): + if node in chunk_info["args"]["prepose_nodes"]: + continue + if node.args[0] in chunk_info["inputs_non_chunk"]: + continue reshape_args = flat_list(node.args[1:]) + if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( + reshape_args[0].meta['fwd_out']) > 1: + continue chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] new_shape = "" for reshape_arg_dim, reshape_arg in enumerate(reshape_args): @@ -409,45 +473,8 @@ class TraceFlow(object): chunk_info["reshape_size"] = reshape_size return chunk_info - def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List: - """ - Find chunk info for a region. - - We are given the region start and region end, and need to find out all chunk info for it. - We first loop every dim of start node and end node, to see if we can find dim pair, - which is linked in a flow and not computed. - If found, we then search flow in the whole region to find out all chunk infos. - - Args: - input_trace (List): node's input trace in region - output_trace (List): node's output trace in region - start_idx (int): region start node index - end_idx (int): region end node index - - Returns: - chunk_infos: possible regions found - """ - start_traces = input_trace[start_idx] - if len(start_traces) > 1: # TODO need to be removed - return [] - end_trace = output_trace[end_idx] - end_node = self.trace_indice.node_list[end_idx] - - chunk_infos = [] - for end_dim, _ in enumerate(end_trace["indice"]): - for start_node, start_trace in start_traces.items(): - for start_dim, _ in enumerate(start_trace["indice"]): - if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx): - continue - # flow search - chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim) - if chunk_info is None: - continue - chunk_infos.append(chunk_info) - return chunk_infos - - def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, - end_idx: int) -> bool: + def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, + end_idx: int) -> bool: """ check if region start and end is legal """ diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 8f517cf2c..b591fa764 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -3,14 +3,7 @@ from typing import Dict, List, Tuple from torch.fx.node import Node -from .utils import ( - find_first_tensor_arg, - find_idx_by_name, - flat_list, - get_module_node_name, - get_node_name, - get_node_shape, -) +from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape class TraceIndice(object): @@ -35,8 +28,8 @@ class TraceIndice(object): node_list (List) """ - def __init__(self, node_list: List[Node]) -> None: - self.node_list = node_list + def __init__(self, node_mgr: NodeMgr) -> None: + self.node_mgr = node_mgr self.indice_trace_list = self._init_indice_trace_list() self.indice_view_list = {} self.indice_count = -1 @@ -45,7 +38,7 @@ class TraceIndice(object): def _init_indice_trace_list(self) -> List: indice_trace_list = [] - for n in self.node_list: + for n in self.node_mgr.get_node_list(): if get_node_shape(n) != None: cur_trace = { "indice": [None for _ in range(len(get_node_shape(n)))], @@ -99,7 +92,7 @@ class TraceIndice(object): node_from_trace_source = self._find_source_trace_from_node(node_from) node_to_dim = self._transform_indice(node_to, node_to_dim) node_to_trace_source = self._find_source_trace_from_node(node_to) - node_from_idx = find_idx_by_name(node_from.name, self.node_list) + node_from_idx = self.node_mgr.find_node_idx(node_from) if init: node_to_trace_source[node_to_dim] = {} # add dim to cur new source @@ -200,7 +193,7 @@ class TraceIndice(object): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = find_idx_by_name(node.name, self.node_list) + node_idx = self.node_mgr.find_node_idx(node) node_dict = self.indice_trace_list[node_idx] return node_dict @@ -214,7 +207,7 @@ class TraceIndice(object): idx (list): idx of the node compute (list): computed idx of the node. """ - node_idx = find_idx_by_name(node.name, self.node_list) + node_idx = self.node_mgr.find_node_idx(node) node_dict = self.indice_trace_list[node_idx] return node_dict["source"] @@ -227,7 +220,7 @@ class TraceIndice(object): Returns: idx (list): idx of the node """ - node_idx = find_idx_by_name(node.name, self.node_list) + node_idx = self.node_mgr.find_node_idx(node) return self.indice_trace_list[node_idx]["indice"] def _find_compute_trace_from_node(self, node: Node) -> List: @@ -239,7 +232,7 @@ class TraceIndice(object): Returns: compute (list): computed idx of the node. """ - node_idx = find_idx_by_name(node.name, self.node_list) + node_idx = self.node_mgr.find_node_idx(node) return self.indice_trace_list[node_idx]["compute"] def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None: @@ -454,8 +447,6 @@ class TraceIndice(object): node (node) node_idx (int) """ - for _ in range(len(get_node_shape(node.args[0]))): - self._add_dim(node_idx, 0) self._assign_indice_as_input(node, node_idx) dim_idx = node.kwargs["dim"] self._del_dim(node_idx, dim_idx) @@ -702,21 +693,20 @@ class TraceIndice(object): if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from and view_dict["dim_from"] == dim_to): # inheirt indice from current node - for dim_to_i in dim_to: - for dim_from_i in dim_from: - self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False) + if len_diff == 1: + if origin_shape[dim_from[0]] == 1: + self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False) + elif origin_shape[dim_from[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + elif len_diff == -1: + if target_shape[dim_to[0]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) + elif target_shape[dim_to[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) # inherid indice from input node of last view for dim_to_i in dim_to: self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False) - # inherit computation - compute_log = self._find_compute_trace_from_node(origin_node) - for i in dim_from: - if origin_trace[i] in compute_log: - for j in dim_to: - self._mark_computation(node, node_idx, [j]) - break - # log view, not used now view_dict = { "idx_from": [origin_trace[i] for i in dim_from], @@ -742,7 +732,7 @@ class TraceIndice(object): active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1] active_nodes = set(flat_list(active_nodes)) - active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes] + active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes] for i in range(trace_range[0], trace_range[1] + 1): trace = self.indice_trace_list[i] # clear compute @@ -758,7 +748,7 @@ class TraceIndice(object): dim_source.pop(k) def trace_indice(self) -> None: - for idx, node in enumerate(self.node_list): + for idx, node in enumerate(self.node_mgr.get_node_list()): node_name = get_node_name(node) if node.op == "placeholder": self._assign_all_indice(node, idx) diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index de081b41c..c6bbc219e 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -9,6 +9,59 @@ NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "siz logger = get_dist_logger() +class NodeMgr(object): + + def __init__(self, gm) -> None: + self._node_list = list(gm.graph.nodes) + self._node_dict = {} + self._set_node_dict() + + def _set_node_dict(self) -> None: + """ + create a dict {node_name: node_idx} + """ + self._node_dict.clear() + for idx, node in enumerate(self._node_list): + self._node_dict[node.name] = idx + + def find_node_idx(self, node: Node) -> int: + """ + find node's index + """ + return self._node_dict[node.name] + + def find_node_idx_by_name(self, node_name: str) -> int: + """ + find node's index + """ + return self._node_dict[node_name] + + def get_node_by_idx(self, idx: int) -> Node: + """ + get a node by index + """ + return self._node_list[idx] + + def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]: + """ + get a slice of node by index + """ + return self._node_list[start:end] + + def get_node_list(self) -> List: + """ + get full node list + """ + return self._node_list + + def update_node_list(self, node_list: List) -> None: + """ + update node list, reset node dict + """ + self._node_list = node_list + self._set_node_dict() + + def get_logger() -> Any: return logger @@ -42,6 +95,8 @@ def is_non_compute_node(node: Node) -> bool: if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME): return True if "getitem" in node.name: + if get_node_shape(node) is not None: + return False node_args = flat_list(node.args[1:]) for node_arg in node_args: if any(i == str(node_arg) for i in ["None", "Ellipsis"]): @@ -53,6 +108,8 @@ def is_non_compute_node(node: Node) -> bool: def get_node_shape(node: Node) -> List: + if get_node_name(node) == "split": + return node.meta["tensor_meta"][0].shape if hasattr(node.meta["tensor_meta"], "shape"): return node.meta["tensor_meta"].shape return None @@ -78,7 +135,7 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool: return is_non_compute_node_except_placeholder(node) -def find_idx_by_name(name: str, nodes_list: List) -> int: +def find_node_idx(name: str, nodes_list: List) -> int: for idx, node in enumerate(nodes_list): if node.name == name: return idx @@ -162,3 +219,28 @@ def get_node_name(node: Node) -> str: else: break return node_name + + +def find_tensor_node(node_list: List[Node]) -> List[Node]: + """ + find tensor nodes from a node list + """ + out = [] + for node in node_list: + if get_node_shape(node) is not None: + out.append(node) + return out + + +def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: + """ + find tensor and shape nodes from a node list + """ + out = [] + for node in node_list: + if get_node_shape(node) is not None: + out.append(node) + elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( + node.meta['fwd_out'][0], int): + out.append(node) + return out diff --git a/tests/test_autochunk/test_alphafold/test_alphafold_utils.py b/tests/test_autochunk/test_alphafold/test_alphafold_utils.py index b05191d2b..cb250d640 100644 --- a/tests/test_autochunk/test_alphafold/test_alphafold_utils.py +++ b/tests/test_autochunk/test_alphafold/test_alphafold_utils.py @@ -23,6 +23,7 @@ def assert_codegen_run( concrete_args: List = None, max_memory: int = None, print_mem: bool = False, + print_est_mem: bool = False, print_progress: bool = False, print_code: bool = False, ) -> List[Dict]: @@ -41,7 +42,7 @@ def assert_codegen_run( codegen = AutoChunkCodeGen( meta_graph, max_memory=max_memory, - print_mem=print_mem, + print_mem=print_est_mem, print_progress=print_progress, ) chunks = codegen.chunk_infos @@ -61,13 +62,20 @@ def assert_codegen_run( code = graph.python_code("self").src if print_code: print(code) - assert "chunk_result = None; chunk_size = None;" in code + assert "chunk_size = None; " in code # assert result inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] + inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] model.cuda() with torch.no_grad(): - out_gm = gm(*inputs) + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print("mem: %.2fMB" % (new_max_mem - now_mem)) out_model = model(*inputs) out_gm = flat_list(out_gm) out_model = flat_list(out_model) @@ -85,9 +93,10 @@ def run_test( max_memory: int, get_model: Any, get_data: Any, - print_code: bool, - print_mem: bool, - print_progress: bool, + print_code: bool = False, + print_mem: bool = False, + print_est_mem: bool = False, + print_progress: bool = False, get_chunk_target: Any = None, ) -> None: # launch colossalai @@ -110,6 +119,7 @@ def run_test( max_memory=max_memory, print_code=print_code, print_mem=print_mem, + print_est_mem=print_est_mem, print_progress=print_progress, ) diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_block.py b/tests/test_autochunk/test_alphafold/test_evoformer_block.py index 787067daa..99a54fe18 100644 --- a/tests/test_autochunk/test_alphafold/test_evoformer_block.py +++ b/tests/test_autochunk/test_alphafold/test_evoformer_block.py @@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)], - 20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)], - 24: [(118, 123)], + None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242), + (25, 50)], + 20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)], + 24: [(120, 123)], } @@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory): get_model=get_model, get_data=get_data, get_chunk_target=get_chunk_target, - print_code=False, - print_mem=False, - print_progress=False, ) mp.spawn(run_func, nprocs=1) @@ -86,10 +84,12 @@ if __name__ == "__main__": run_test( rank=0, data_args=(32, 64), - max_memory=20, + max_memory=24, get_model=get_model, get_data=get_data, + get_chunk_target=get_chunk_target, print_code=False, print_mem=False, + print_est_mem=False, print_progress=False, ) diff --git a/tests/test_autochunk/test_alphafold/test_evoformer_stack.py b/tests/test_autochunk/test_alphafold/test_evoformer_stack.py index 45d8e7ac8..06aba0799 100644 --- a/tests/test_autochunk/test_alphafold/test_evoformer_stack.py +++ b/tests/test_autochunk/test_alphafold/test_evoformer_stack.py @@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory): max_memory=max_memory, get_model=get_model, get_data=get_data, - print_code=False, - print_mem=False, - print_progress=False, ) mp.spawn(run_func, nprocs=1) @@ -81,7 +78,7 @@ if __name__ == "__main__": run_test( rank=0, data_args=(32, 64), - max_memory=20, + max_memory=None, get_model=get_model, get_data=get_data, print_code=False, diff --git a/tests/test_autochunk/test_alphafold/test_extramsa_block.py b/tests/test_autochunk/test_alphafold/test_extramsa_block.py index a2b72ed1a..1b0273a16 100644 --- a/tests/test_autochunk/test_alphafold/test_extramsa_block.py +++ b/tests/test_autochunk/test_alphafold/test_extramsa_block.py @@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250), - (33, 46)], - 20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)], - 24: [(126, 131)], + None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250), + (36, 46)], + 20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)], + 24: [(128, 131)], } @@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory): max_memory=max_memory, get_model=get_model, get_data=get_data, - print_code=False, - print_mem=False, - print_progress=False, + get_chunk_target=get_chunk_target, ) mp.spawn(run_func, nprocs=1) @@ -86,7 +84,7 @@ if __name__ == "__main__": run_test( rank=0, data_args=(32, 64), - max_memory=20, + max_memory=None, get_model=get_model, get_data=get_data, get_chunk_target=get_chunk_target, diff --git a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py index 0ba8f89c2..256df8bbb 100644 --- a/tests/test_autochunk/test_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_transformer/test_autochunk_gpt.py @@ -17,8 +17,8 @@ from test_transformer_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE -BATCH_SIZE = 2 -SEQ_LENGTH = 256 +BATCH_SIZE = 1 +SEQ_LENGTH = 512 def get_data(shape: tuple) -> Tuple[List, List]: @@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]: ) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) -@pytest.mark.parametrize("max_memory", [None, 4.5, 5]) -def test_gpt(model, shape, max_memory): +@pytest.mark.parametrize("max_memory", [None, 6, 8]) +def test_autochunk_gpt(model, shape, max_memory): run_func = partial( run_test, data=get_data(shape), max_memory=max_memory, model=model, config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4), - print_code=False, - print_mem=False, - print_progress=False, ) mp.spawn(run_func, nprocs=1) @@ -59,7 +56,8 @@ if __name__ == "__main__": max_memory=None, model=GPT2Model, config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), - print_code=True, - print_mem=True, + print_code=False, + print_est_mem=False, + print_mem=False, print_progress=False, ) diff --git a/tests/test_autochunk/test_transformer/test_transformer_utils.py b/tests/test_autochunk/test_transformer/test_transformer_utils.py index d33fc04c5..cc26168c7 100644 --- a/tests/test_autochunk/test_transformer/test_transformer_utils.py +++ b/tests/test_autochunk/test_transformer/test_transformer_utils.py @@ -20,6 +20,7 @@ def assert_codegen_run( model: Any, data: tuple, max_memory: int = None, + print_est_mem: bool = False, print_mem: bool = False, print_progress: bool = False, print_code: bool = False, @@ -41,7 +42,7 @@ def assert_codegen_run( codegen = AutoChunkCodeGen( meta_graph, max_memory=max_memory, - print_mem=print_mem, + print_mem=print_est_mem, print_progress=print_progress, ) chunks = codegen.chunk_infos @@ -61,7 +62,7 @@ def assert_codegen_run( code = graph.python_code("self").src if print_code: print(code) - assert "chunk_result = None; chunk_size = None;" in code + assert "chunk_size = None; " in code # assert result inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] @@ -69,26 +70,44 @@ def assert_codegen_run( model.cuda().eval() gm.eval() with torch.no_grad(): - out_gm = gm(*inputs) + if print_mem: + torch.cuda.reset_peak_memory_stats() + now_mem = torch.cuda.memory_allocated() / 1024**2 + out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + if print_mem: + new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 + print("mem: %.2fMB" % (new_max_mem - now_mem)) out_model = model(*inputs) - for k in out_model.keys(): - if torch.is_tensor(out_gm[k]): - assert torch.equal( - out_model[k], out_gm[k] - ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}' - + assert_allclose(out_model, out_gm) return chunks +def assert_allclose(out_model: Any, out_gm: Any) -> None: + """ + assert allclose for out + """ + if isinstance(out_model, torch.Tensor): + assert torch.allclose(out_model, out_gm, + atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_model - out_gm)) + elif isinstance(out_model, dict): + for k in out_model.keys(): + assert_allclose(out_model[k], out_gm[k]) + elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set): + for i, j in zip(out_model, out_gm): + assert_allclose(i, j) + + def run_test( rank: int, model: Any, config: Any, data: tuple, max_memory: int, - print_code: bool, - print_mem: bool, - print_progress: bool, + print_code: bool = False, + print_est_mem: bool = False, + print_mem: bool = False, + print_progress: bool = False, get_chunk_target: Any = None, ) -> None: model = model(config=config) @@ -108,6 +127,7 @@ def run_test( data=data, max_memory=max_memory, print_code=print_code, + print_est_mem=print_est_mem, print_mem=print_mem, print_progress=print_progress, ) @@ -119,5 +139,3 @@ def run_test( str(chunk_found), str(chunk_target), ) - - gpc.destroy()