diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 82937db9f..90bde8730 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -9,18 +9,7 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
 AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
 
 if AUTOCHUNK_AVAILABLE:
-    from torch.fx.graph import (
-        CodeGen,
-        PythonCode,
-        _custom_builtins,
-        _CustomBuiltin,
-        _format_target,
-        _is_from_torch,
-        _Namespace,
-        _origin_type_map,
-        inplace_methods,
-        magic_methods,
-    )
+    from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
 
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
 
@@ -143,7 +132,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     return context
 
 
-def _replace_ones_like(
+def _replace_new_tensor_like_shape(
     search_chunk: SearchChunk,
     chunk_infos: List[Dict],
     region_idx: int,
@@ -154,7 +143,7 @@ def _replace_ones_like(
     """
     add chunk slice for new tensor op such as ones like
     """
-    if "ones_like" in node.name:
+    if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
         meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
         chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
         if get_node_shape(meta_node)[chunk_dim] != 1:
@@ -166,6 +155,33 @@ def _replace_ones_like(
     return body
 
 
+def _replace_new_tensor_shape(
+    search_chunk: SearchChunk,
+    chunk_infos: List[Dict],
+    region_idx: int,
+    node_idx: int,
+    node: Node,
+    body: List[str],
+) -> List[str]:
+    """
+    add chunk slice for new tensor op such as ones
+    """
+    if get_node_name(node) in ["ones", "zeros", "empty"]:
+        meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
+        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
+        if chunk_dim is None:
+            return
+        if get_node_shape(meta_node)[chunk_dim] == 1:
+            return
+        origin_shape = str(node.args)
+        new_shape = list(node.args)
+        new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
+        new_shape = str(new_shape)
+        new_shape = new_shape.replace("'", "")
+        body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
+    return body
+
+
 def _add_node_slice(
     chunk_nodes: List[Node],
     region_idx: int,
@@ -265,8 +281,10 @@ def emit_code_with_chunk(
             body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
             # replace output var with chunk var
             body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
-            # ones like
-            body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # new tensor like
+            body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+            # new tensor
+            body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
             # reassgin reshape size
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = " " + body[-1]
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 0278e03f7..eb9949095 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,14 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    NodeMgr,
-    find_chunk_compute_input_and_output_nodes,
-    get_logger,
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
 
 
 class SearchChunk(object):
@@ -75,8 +68,8 @@ class SearchChunk(object):
         max_chunk_region_list = []
         while True:
             max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
-            cur_node_idx = max_chunk_region[1]
-            if cur_node_idx == len(active_nodes) - 1:
+            cur_node_idx = max_chunk_region[1] + 1
+            if cur_node_idx >= len(active_nodes) - 1:
                 break
             max_chunk_region_list.append(max_chunk_region)
 
@@ -135,6 +128,7 @@ class SearchChunk(object):
         min_active_node_num = min(active_node_num[free_var_num:])
         threshold = max(free_var_num, min_active_node_num)
 
+        # normal search
         # from peak_node to free_var
         inside_flag = False
         chunk_region_start = free_var_num
@@ -144,7 +138,6 @@ class SearchChunk(object):
             if inside_flag and active_node_num[i] > threshold:
                 chunk_region_start = i + 1
                 break
-
         # from peak_node to len-2
         inside_flag = False
         chunk_region_end = len(active_node) - 1
@@ -155,6 +148,22 @@ class SearchChunk(object):
                 chunk_region_end = i
                 break
 
+        # if normal search fails, use approximate search
+        if (chunk_region_end - chunk_region_start) > 250:
+            window_size = 100
+            # search min for start
+            min_num = 1e3
+            for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
+                if active_node_num[i] < min_num:
+                    min_num = active_node_num[i]
+                    chunk_region_start = i
+            # search min for end
+            min_num = 1e3
+            for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
+                if active_node_num[i] < min_num:
+                    min_num = active_node_num[i]
+                    chunk_region_end = i
+
         # avoid chunk regions overlap
         if chunk_regions is not None:
             for i in chunk_regions:
@@ -271,12 +280,6 @@ class SearchChunk(object):
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
 
-    def _stop_search(self, init_mem_peak, mem_peak):
-        sorted_init_mem_peak = sorted(init_mem_peak)
-        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
-            return True
-        return False
-
     def search_region(self) -> Dict:
         """
         Search all chunk regions:
@@ -291,11 +294,7 @@ class SearchChunk(object):
             get_logger().info("AutoChunk start searching chunk regions")
 
         chunk_infos = []
-        (
-            init_mem_peak,
-            _,
-            active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
+        init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
         mem_peak = init_mem_peak
 
         while True:
@@ -304,18 +303,13 @@ class SearchChunk(object):
                 break
             chunk_infos.append(chunk_info)
 
-            (
-                mem_peak,
-                _,
-                active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
+            mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
+                self.node_mgr.get_node_list(), chunk_infos)
 
             if self.print_progress:
                 get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
                                   (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
 
-            if self._stop_search(init_mem_peak, mem_peak):
-                break
         if self.print_mem:
             self.print_mem = False
             self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index 11dbb266d..16815215f 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -100,6 +100,16 @@ class TraceFlow(object):
         if not (start_idx <= arg_idx < end_idx):
             return True
 
+        # get fix dim
+        arg_fix_dim = []
+        if cur_node_dim is not None:
+            for i in cur_node_fix_dim:
+                fix_dim_source = cur_node_source[i]
+                if arg_idx in fix_dim_source:
+                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+        if arg_node in all_node_info:
+            arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+
         # find arg dim
         if cur_node_dim is not None:
             # dim is computed
@@ -109,6 +119,9 @@ class TraceFlow(object):
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim cannot be in fix dims
+                if arg_dim in arg_fix_dim:
+                    return False
                 # chunk dim should be None if shape size is 1
                 if get_node_shape(arg_node)[arg_dim] == 1:
                     arg_dim = None
@@ -120,19 +133,16 @@ class TraceFlow(object):
         else:
             arg_dim = None
 
-        # get fix dim
-        arg_fix_dim = []
-        if cur_node_dim is not None:
-            for i in cur_node_fix_dim:
-                fix_dim_source = cur_node_source[i]
-                if arg_idx in fix_dim_source:
-                    arg_fix_dim.append(fix_dim_source[arg_idx][0])
+        # add arg rest dim as fix dim
+        arg_fix_dim = list(range(len(get_node_shape(arg_node))))
+        if arg_dim is not None:
+            arg_fix_dim.remove(arg_dim)
 
         # if already in node_info, arg dim must be same
         if arg_node in all_node_info:
             if all_node_info[arg_node]["chunk_dim"] != arg_dim:
                 return False
-            all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+            all_node_info[arg_node]["fix_dim"] = arg_fix_dim
         # else add it to list
         else:
             all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -164,6 +174,8 @@ class TraceFlow(object):
                     continue
                 if is_non_compute_node(arg):
                     continue
+                if get_node_shape(arg) is None:
+                    continue
                 arg_list.append(arg)
                 flow_flag = self._assgin_single_node_flow(
                     arg,
@@ -180,29 +192,6 @@ class TraceFlow(object):
 
                 if flow_flag == False:
                     return None
-            if len(arg_list) >= 2:
-                # need to mark fix dim
-                if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
-                    for arg in arg_list:
-                        if get_node_shape(arg) is None:
-                            continue
-                        if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
-                            continue
-                        arg_chunk_dim = all_node_info[arg]["chunk_dim"]
-                        arg_fix_dim = all_node_info[arg]["fix_dim"]
-                        arg_shape = get_node_shape(arg)
-                        # add all dim as fix dim except chunk dim
-                        for i, shape in enumerate(arg_shape):
-                            if shape != 1 and i != cur_node_chunk_dim:
-                                if i == arg_chunk_dim:
-                                    return None
-                                if i not in arg_fix_dim:
-                                    arg_fix_dim.append(i)
-                elif any(i == get_node_name(cur_node)
-                         for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
-                    pass
-                else:
-                    raise NotImplementedError()
 
             cur_node_list = next_node_list
         return all_node_info
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index b591fa764..1e41073d7 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -150,7 +150,7 @@ class TraceIndice(object):
         for i in range(len(node_from_indice)):
             self._inherit_indice(node_from, i, node_to, i, init=True)
 
-    def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
+    def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
         """
         inheirt indice from node without init
         """
@@ -308,14 +308,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        if len(node.args) == 2:
-            _, weight = node.args
-        else:
-            _, weight, _ = node.args
-
         self._assign_indice_as_input(node, node_idx)
-        self._inherit_indice(weight, 1, node, -1)
+
+        if len(node.args) >= 2:
+            weight = node.args[1]
+            self._inherit_indice(weight, 1, node, -1)
+        else:
+            self._del_dim(node_idx, -1)
+            self._add_dim(node_idx, -1)
 
         self._mark_computation(node, node_idx, [-1])
 
     def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
         """
@@ -327,13 +327,35 @@ class TraceIndice(object):
             node_idx (int)
         """
         bias, input_node, weight = node.args
-
+        assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
         self._assign_indice_as_input(node, node_idx, input_node)
         self._inherit_indice(weight, 1, node, -1)
-        self._inherit_indice(bias, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(bias, node)
 
         self._mark_computation(node, node_idx, [-1])
 
+    def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for baddbmm(batch add and batch matmul) op.
+        add, matmul_left, matmul_right = args
+        out = add + (matmul_left x matmul_right)
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        add, matmul_left, matmul_right = node.args
+
+        assert get_node_shape(add) == get_node_shape(node)
+        assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
+        self._assign_indice_as_input(node, node_idx, matmul_left)
+        # matmul
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
+        self._mark_computation(node, node_idx, [-1])
+        # add
+        self._inherit_more_indice_from_node_with_exclude(add, node)
+
     def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for matmul op.
@@ -349,11 +371,53 @@ class TraceIndice(object):
         assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
 
         self._assign_indice_as_input(node, node_idx, matmul_left)
-        self._inherit_indice(matmul_right, -1, node, -1)
-        self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
+        self._inherit_indice(matmul_right, -1, node, -1)
+        self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
 
         self._mark_computation(node, node_idx, [-1])
 
+    def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for conv2d op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get conv module
+        node_targets = node.target.split(".")
+        conv_module = node.graph.owning_module
+        for i in node_targets:
+            conv_module = getattr(conv_module, i)
+        assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
+
+        # get conv input
+        assert len(node.args) == 1
+        input_node = node.args[0]
+        assert len(get_node_shape(input_node)) == 4
+
+        # assgin index
+        self._assign_indice_as_input(node, node_idx, input_node)
+        self._del_dim(node_idx, 1)
+        self._add_dim(node_idx, 1)
+        self._mark_computation(node, node_idx, [1, 2, 3])
+
+    def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for interpolate op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        # get conv input
+        assert node.kwargs['size'] is None
+        assert len(get_node_shape(node)) == 4
+
+        # assgin index
+        self._assign_indice_as_input(node, node_idx)
+        self._mark_computation(node, node_idx, [-1, -2])
+
     def _assign_layernorm_indice(self, node, idx):
         """
         Assign indice for layernorm op.
@@ -367,6 +431,18 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [-1])
 
+    def _assign_groupnorm_indice(self, node, idx):
+        """
+        Assign indice for groupnorm op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        assert len(get_node_shape(node)) == 4
+        self._assign_indice_as_input(node, idx)
+        self._mark_computation(node, idx, [-1, -2, -3])
+
     def _assign_elementwise_indice(self, node, idx):
         """
         Assign indice for element-wise op (eg. relu sigmoid add mul).
@@ -382,13 +458,13 @@ class TraceIndice(object):
         for node_in in node.args:
             if type(node_in) == type(node):
                 nodes_in.append(node_in)
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)
 
     def _assgin_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
-                self._inherit_more_indice_from_node(node_in, node)
+                self._inherit_more_indice_from_node_with_exclude(node_in, node)
 
     def _assign_einsum_indice(self, node, idx):
         """
@@ -469,17 +545,6 @@ class TraceIndice(object):
         dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
         self._add_dim(node_idx, dim_idx)
 
-    def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for oneslike op.
-        1. assign new indice for all dim
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
     def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for cat op.
@@ -491,7 +556,7 @@ class TraceIndice(object):
         nodes_in = flat_list(node.args[0])
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
         self._add_dim(node_idx, cat_dim)
@@ -508,33 +573,10 @@ class TraceIndice(object):
         self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._inherit_more_indice_from_node(n, node)
+            self._inherit_more_indice_from_node_with_exclude(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
 
-    def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for arange op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_all_indice(node, node_idx)
-
-    def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
-        """
-        Assign indice for tensor op.
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        if len(get_node_shape(node)) == 0:
-            return
-        else:
-            raise NotImplementedError()
-
     def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for embedding op.
@@ -763,10 +805,10 @@ class TraceIndice(object):
                     self._assign_unsqueeze_indice(node, idx)
                 elif "split" == node_name:
                     self._assign_split_indice(node, idx)
-                elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
+                elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
                     self._assgin_no_change_indice(node, idx)
                 elif "new_ones" == node_name:
-                    self._assign_ones_like_indice(node, idx)
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["size"]):
                     continue
                 else:
@@ -776,25 +818,15 @@ class TraceIndice(object):
                     self._assign_linear_indice(node, idx)
                 elif "cat" == node_name:
                     self._assign_cat_indice(node, idx)
-                elif "matmul" == node_name:
+                elif any(n == node_name for n in ["matmul", "bmm"]):
                     self._assign_matmul_indice(node, idx)
                 elif "softmax" == node_name:
                     self._assign_softmax_indice(node, idx)
                 elif any(n == node_name for n in [
-                        "mul",
-                        "add",
-                        "sigmoid",
-                        "relu",
-                        "sub",
-                        "truediv",
-                        "pow",
-                        "dropout",
-                        "where",
-                        "tanh",
+                        "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
+                        "sin", "cos"
                 ]):
                     self._assign_elementwise_indice(node, idx)
-                elif "ones_like" == node_name:
-                    self._assign_ones_like_indice(node, idx)
                 elif "einsum" == node_name:
                     self._assign_einsum_indice(node, idx)
                 elif "sum" == node_name:
@@ -805,10 +837,12 @@ class TraceIndice(object):
                     self._assign_getitem_indice(node, idx)
                 elif "addmm" == node_name:
                     self._assign_addmm_indice(node, idx)
-                elif "arange" == node_name:
-                    self._assign_arange_indice(node, idx)
-                elif "tensor" == node_name:
-                    self._assign_arange_indice(node, idx)
+                elif "baddbmm" == node_name:
+                    self._assign_baddbmm_indice(node, idx)
+                elif "interpolate" == node_name:
+                    self._assign_interpolate_indice(node, idx)
+                elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
+                    self._assign_all_indice(node, idx)
                 elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
                     continue
                 else:
@@ -817,9 +851,15 @@ class TraceIndice(object):
                 node_name = get_module_node_name(node)
                 if "layernorm" == node_name:
                     self._assign_layernorm_indice(node, idx)
+                elif "groupnorm" == node_name:
+                    self._assign_groupnorm_indice(node, idx)
                 elif "embedding" == node_name:
                     self._assign_embedding_indice(node, idx)
-                elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
+                elif "linear" == node_name:
+                    self._assign_linear_indice(node, idx)
+                elif "conv2d" == node_name:
+                    self._assign_conv2d_indice(node, idx)
+                elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
                     self._assign_elementwise_indice(node, idx)
                 else:
                     raise NotImplementedError(node_name, "module not implemented yet!")
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
index 0f3d22dc5..529250fe8 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
@@ -22,6 +22,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:
@@ -35,13 +36,14 @@ def assert_codegen_run(
         meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
         concrete_args={k: v for k, v in concrete_args},
     )
+    model = model.cuda().eval()
     interp = MetaInfoProp(meta_graph)
     meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
     interp.propagate(*meta_tensors)
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos
@@ -61,17 +63,29 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code
 
     # assert result
     inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
     model.cuda().eval()
     gm.eval()
     with torch.no_grad():
-        out_gm = gm(*inputs)
-        out_model = model(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+        out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+            print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
     assert torch.allclose(out_gm["sample"], out_model["sample"],
-                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                          atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                               torch.abs(out_gm["sample"] - out_model["sample"]))
     return chunks
@@ -82,9 +96,10 @@ def run_test(
     model: Any,
     data: tuple,
     max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     # launch colossalai
@@ -106,6 +121,7 @@ def run_test(
         max_memory=max_memory,
         print_code=print_code,
         print_mem=print_mem,
+        print_est_mem=print_est_mem,
         print_progress=print_progress,
     )
 
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index 9ebe6f393..518c7f451 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -17,10 +17,9 @@ from test_autochunk_diffuser_utils import run_test
 
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 
-BATCH_SIZE = 2
-SEQ_LENGTH = 5
-HEIGHT = 224
-WIDTH = 224
+BATCH_SIZE = 1
+HEIGHT = 448
+WIDTH = 448
 IN_CHANNELS = 3
 LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
@@ -34,26 +33,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     return meta_args, concrete_args
 
 
-@pytest.mark.skipif(
-    True,
-    reason="not implemented",
-)
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [64])
+@pytest.mark.parametrize("max_memory", [None])
 def test_evoformer_block(model, shape, max_memory):
     run_func = partial(
         run_test,
         max_memory=max_memory,
         model=model,
         data=get_data(shape),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
 
@@ -62,9 +54,10 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data=get_data(LATENTS_SHAPE),
-        max_memory=64,
+        max_memory=None,
         model=UNet2DModel,
         print_code=False,
         print_mem=False,
+        print_est_mem=False,
         print_progress=False,
     )
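
For reference, the shape convention the new _assign_baddbmm_indice handler assumes (out = add + matmul_left x matmul_right, with add already matching the output shape) can be checked directly against torch.baddbmm. The snippet below is only an illustration with made-up sizes, not part of the patch.

import torch

# out = add + matmul_left @ matmul_right, batched over dim 0
B, n, k, p = 2, 3, 4, 5                   # hypothetical sizes
add = torch.randn(B, n, p)                # must already match the output shape
matmul_left = torch.randn(B, n, k)
matmul_right = torch.randn(B, k, p)

out = torch.baddbmm(add, matmul_left, matmul_right)
# batch and row dims follow matmul_left, the last dim follows matmul_right
assert out.shape == (B, n, p)
# same result written out explicitly: batch matmul, then elementwise add
assert torch.allclose(out, add + torch.bmm(matmul_left, matmul_right), atol=1e-5)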