diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index 15e15517b..2cbc6c922 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -63,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
     context = ""
     for i in range(len(chunk_output)):
         shape_str = str(list(get_node_shape(chunk_output[i])))
-        if get_node_name(chunk_output[i]) == "split":
+        if get_node_name(chunk_output[i]) in ["split", "unbind"]:
             tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
                                                                                   input_node.name)
             tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
@@ -205,7 +205,7 @@ def _add_node_slice(
         if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
             chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
                                                get_node_shape(chunk_node))
-            if get_node_name(chunk_node) == "split":
+            if get_node_name(chunk_node) in ["split", "unbind"]:
                 split_chunk_slice = ""
                 for i in range(len(chunk_node.meta['tensor_meta'])):
                     split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index 92199b79a..307f4de32 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -74,6 +74,9 @@ class TraceIndice(object):
         """
         add a dim for indice, compute and source
         """
+        # need to remap if dim_idx < 0, e.g. -1
+        if dim_idx < 0:
+            dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx]
         self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
         self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
         self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
@@ -575,6 +578,60 @@ class TraceIndice(object):
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)

+    def _assign_flatten_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for flatten op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = node.args[0]
+        nodes_in_shape = get_node_shape(nodes_in)
+        flatten_start_dim = node.args[1]
+        flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1
+        assert flatten_dim_num > 0
+        for _ in range(flatten_dim_num):
+            self._add_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx, nodes_in)
+        for _ in range(flatten_dim_num + 1):
+            self._del_dim(node_idx, -1)
+        self._add_dim(node_idx, -1)
+
+    def _assign_expand_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for expand op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        expand_shape = node.args[1:]
+        node_in_shape = get_node_shape(node.args[0])
+        assert len(expand_shape) == len(node_in_shape)
+        self._assign_indice_as_input(node, node_idx)
+        for i in range(len(node_in_shape)):
+            if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1:
+                continue
+            elif expand_shape[i] > node_in_shape[i]:
+                self._del_dim(node_idx, i)
+                self._add_dim(node_idx, i)
+            else:
+                raise RuntimeError()
+
+    def _assign_unbind_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for unbind op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        unbind_dim = node.args[1]
+        self._add_dim(node_idx, unbind_dim)
+        self._assign_indice_as_input(node, node_idx)
+        self._del_dim(node_idx, unbind_dim)
+
     def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for embedding op.
@@ -695,32 +752,39 @@ class TraceIndice(object):
             shape_idx = target_shape.index(-1)
             target_shape[shape_idx] = origin_product // target_product

-        # determine changed dim
-        len_diff = len(origin_shape) - len(target_shape)
-        if len_diff == 1:
+        # find same dim
+        dim_to_same_dim = []
+        dim_from_same_dim = []
+        for i in range(len(origin_shape)):
+            if origin_shape[i] == target_shape[i]:
+                dim_to_same_dim.append(i)
+                dim_from_same_dim.append(i)
+            else:
+                break
+        for i in range(-1, -len(origin_shape), -1):
+            if origin_shape[i] == target_shape[i]:
+                dim_to_same_dim.append(len(target_shape) + i)
+                dim_from_same_dim.append(len(origin_shape) + i)
+            else:
+                break
+
+        dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim))
+        dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim))
+        assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to)
+
+        dim_diff = len(dim_from) - len(dim_to)
+        if dim_diff > 0:
             # dim merge
-            dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
-            dim_to = [dim_equal.index(False)]
-            dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
-            self._add_dim(node_idx, -1)
-        elif len_diff == -1:
+            for i in range(dim_diff):
+                self._add_dim(node_idx, -1)
+        elif dim_diff < 0:
             # dim expand
-            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
-            dim_from = [dim_equal.index(False)]
-            dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
-            self._del_dim(node_idx, -1)
-        elif len_diff == 0:
-            # dim equal
-            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
-            dim_from = []
-            dim_to = []
-        else:
-            raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
+            for i in range(-dim_diff):
+                self._del_dim(node_idx, -1)

         # get new indice
         origin_trace = self._find_indice_trace_from_node(origin_node)
         self._assign_indice_as_input(node, node_idx, origin_node)
-        idx_from = [origin_trace[i] for i in dim_from]
         dim_from.reverse()
         for i in dim_from:
             self._del_dim(node_idx, i)
@@ -728,36 +792,18 @@ class TraceIndice(object):
             self._add_dim(node_idx, i)
         dim_from.reverse()

-        # search view list
-        # for view_node, view_dict in self.indice_view_list.items():
-        #     if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
-        #             and view_dict["dim_from"] == dim_to):
-        #         # inheirt indice from current node
-        #         if len_diff == 1:
-        #             if origin_shape[dim_from[0]] == 1:
-        #                 self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
-        #             elif origin_shape[dim_from[1]] == 1:
-        #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-        #         elif len_diff == -1:
-        #             if target_shape[dim_to[0]] == 1:
-        #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
-        #             elif target_shape[dim_to[1]] == 1:
-        #                 self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-        #         # inherid indice from input node of last view
-        #         for dim_to_i in dim_to:
-        #             self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
-        # inheirt indice from current node
-        if len_diff == 1:
-            if origin_shape[dim_from[0]] == 1:
-                self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
-            elif origin_shape[dim_from[1]] == 1:
-                self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
-        elif len_diff == -1:
-            if target_shape[dim_to[0]] == 1:
-                self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
-            elif target_shape[dim_to[1]] == 1:
-                self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+        if len(dim_from) != 0 and len(dim_to) != 0:
+            if dim_diff == 1:
+                if origin_shape[dim_from[0]] == 1:
+                    self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
+                elif origin_shape[dim_from[1]] == 1:
+                    self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+            elif dim_diff == -1:
+                if target_shape[dim_to[0]] == 1:
+                    self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
+                elif target_shape[dim_to[1]] == 1:
+                    self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)

         # log view, not used now
         view_dict = {
@@ -809,6 +855,14 @@ class TraceIndice(object):
                     self._assgin_no_change_indice(node, idx)
                 elif "new_ones" == node_name:
                     self._assign_all_indice(node, idx)
+                elif "flatten" == node_name:
+                    self._assign_flatten_indice(node, idx)
+                elif "expand" == node_name:
+                    self._assign_expand_indice(node, idx)
+                elif "unbind" == node_name:
+                    self._assign_unbind_indice(node, idx)
+                elif "softmax" == node_name:
+                    self._assign_softmax_indice(node, idx)
                 elif any(i == node_name for i in ["size"]):
                     continue
                 else:
@@ -859,7 +913,9 @@ class TraceIndice(object):
                     self._assign_linear_indice(node, idx)
                 elif "conv2d" == node_name:
                     self._assign_conv2d_indice(node, idx)
-                elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
+                elif "identity" == node_name:
+                    self._assgin_no_change_indice(node, idx)
+                elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
                     self._assign_elementwise_indice(node, idx)
                 else:
                     raise NotImplementedError(node_name, "module not implemented yet!")
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 7c0bc29b5..064baa047 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -109,8 +109,11 @@ def is_non_compute_node(node: Node) -> bool:
     return False


-def get_node_shape(node: Node) -> List:
-    if get_node_name(node) == "split":
+def get_node_shape(node: Node) -> Any:
+    """
+    return node data shape
+    """
+    if get_node_name(node) in ["split", "unbind"]:
         return node.meta["tensor_meta"][0].shape
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 6bdec865f..e302c8421 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -359,7 +359,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.where.self,
         aten.zero_.default,
         aten.zeros_like.default,
-        aten.fill_.Scalar
+        aten.fill_.Scalar,
+        aten.stack.default
     ]    # yapf: disable

     for op in zero_flop_aten:
diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py
new file mode 100644
index 000000000..5c127bd69
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py
@@ -0,0 +1,147 @@
+import time
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import parameter_size
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _benchmark_autochunk_unet_gm(
+    model: Any,
+    data: tuple,
+    max_memory: int = None,
+) -> None:
+    model = model.cuda().eval()
+
+    # build model and input
+    meta_args, concrete_args = data
+    if concrete_args is None:
+        concrete_args = {}
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    model = model.cuda().eval()
+    interp = MetaInfoProp(meta_graph)
+    meta_tensors = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
+    interp.propagate(*meta_tensors)
+    codegen = AutoChunkCodeGen(
+        meta_graph,
+        max_memory=max_memory,
+    )
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model.cuda().eval(),
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+        concrete_args={k: v for k, v in concrete_args},
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # init inputs
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+    model.cuda().eval()
+
+    # bench
+    para_mem = float(parameter_size(model)) / 1024**2
+    act_mem = _benchmark_memory(gm, inputs)
+    speed = _benchmark_speed(gm, inputs)
+    print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+          (speed, act_mem, para_mem, act_mem + para_mem))
+
+
+def _benchmark_autochunk_unet_origin(
+    model: Any,
+    data: tuple,
+) -> None:
+    # build model and input
+    meta_args, concrete_args = data
+    if concrete_args is None:
+        concrete_args = {}
+
+    # init inputs
+    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+    model.cuda().eval()
+
+    # bench
+    para_mem = float(parameter_size(model)) / 1024**2
+    act_mem = _benchmark_memory(model, inputs)
+    speed = _benchmark_speed(model, inputs)
+    print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+          (speed, act_mem, para_mem, act_mem + para_mem))
+    return act_mem
+
+
+def _benchmark_memory(model, inputs):
+    with torch.no_grad():
+        torch.cuda.reset_peak_memory_stats()
+        now_mem = float(torch.cuda.memory_allocated()) / 1024**2
+        model(*inputs)
+        new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
+    return new_max_mem - now_mem
+
+
+def _benchmark_speed(model, inputs, loop=5):
+    with torch.no_grad():
+        for _ in range(loop // 2 + 1):
+            model(*inputs)
+        torch.cuda.synchronize()
+        time1 = time.time()
+        for _ in range(loop):
+            model(*inputs)
+        torch.cuda.synchronize()
+        time2 = time.time()
+    return (time2 - time1) / loop
+
+
+def benchmark_autochunk_unet(batch=1, height=448, width=448):
+    from test_autochunk_unet import UNet2DModel, get_data
+    model = UNet2DModel()
+    latent_shape = (batch, 3, height // 7, width // 7)
+
+    print("\nbatch: %d, height: %d, width: %d" % (batch, height, width))
+    max_mem = _benchmark_autochunk_unet_origin(model, get_data(latent_shape))
+    for ratio in [0.5, 0.4, 0.3, 0.2]:
+        try:
+            _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio)
+        except RuntimeError as e:
+            if e.args[0] == 'Search failed. Try a larger memory threshold.':
+                break
+        except Exception as e:
+            raise e
+    _benchmark_autochunk_unet_gm(model, get_data(latent_shape), None)
+
+
+if __name__ == "__main__":
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=0,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+    benchmark_autochunk_unet(batch=1, height=224 * 2, width=224 * 2)
+    benchmark_autochunk_unet(batch=1, height=224 * 3, width=224 * 3)
+    benchmark_autochunk_unet(batch=1, height=224 * 4, width=224 * 4)
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
index 518c7f451..16c5b10ff 100644
--- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -39,7 +39,7 @@ def get_data(shape: tuple) -> Tuple[List, List]:
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [None])
+@pytest.mark.parametrize("max_memory", [None, 150, 300])
 def test_evoformer_block(model, shape, max_memory):
     run_func = partial(
         run_test,
@@ -57,7 +57,7 @@ if __name__ == "__main__":
         max_memory=None,
         model=UNet2DModel,
         print_code=False,
-        print_mem=False,
+        print_mem=True,
         print_est_mem=False,
         print_progress=False,
     )
diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
new file mode 100644
index 000000000..2b7cbf139
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
@@ -0,0 +1,53 @@
+from functools import partial
+from typing import List, Tuple
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+try:
+    from timm.models.vision_transformer import vit_large_patch16_384 as vit
+    MODELS = [vit]
+    HAS_REPO = True
+except:
+    MODELS = []
+    HAS_REPO = False
+
+from test_autochunk_vit_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+
+def get_data() -> Tuple[List, List]:
+    data = torch.rand(1, 3, 384, 384)
+    meta_args = {'x': data}
+    return data, meta_args
+
+
+@pytest.mark.skipif(
+    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+    reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_memory", [None, 32, 40])
+def test_evoformer_block(model, max_memory):
+    run_func = partial(
+        run_test,
+        max_memory=max_memory,
+        model=model,
+        data=get_data(),
+    )
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+    run_test(
+        rank=0,
+        data=get_data(),
+        max_memory=None,
+        model=vit,
+        print_code=False,
+        print_mem=False,
+        print_est_mem=False,
+        print_progress=False,
+    )
diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
new file mode 100644
index 000000000..035dd5979
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def assert_codegen_run(
+    model: Any,
+    meta_args: Dict,
+    data: Any,
+    max_memory: int = None,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
+    print_code: bool = False,
+) -> List[Dict]:
+    model = model()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
+    model = model.cuda().eval()
+    interp = MetaInfoProp(meta_graph)
+    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args.items()]
+    interp.propagate(*meta_tensors)
+    codegen = AutoChunkCodeGen(
+        meta_graph,
+        max_memory=max_memory,
+        print_mem=print_est_mem,
+        print_progress=print_progress,
+    )
+    chunks = codegen.chunk_infos
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model.cuda(),
+        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # assert chunk in code
+    code = graph.python_code("self").src
+    if print_code:
+        print(code)
+    assert "chunk_size = None; " in code
+
+    # assert result
+    inputs = [data.cuda()]
+    model.cuda().eval()
+    gm.eval()
+    with torch.no_grad():
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+            torch.cuda.reset_peak_memory_stats()
+            now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+        out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+            print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
+    assert torch.allclose(out_gm, out_model,
+                          atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(out_gm - out_model))
+
+    return chunks
+
+
+def run_test(
+    rank: int,
+    model: Any,
+    data: tuple,
+    max_memory: int,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
+    get_chunk_target: Any = None,
+) -> None:
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+
+    # build model and input
+    data, meta_args = data
+    chunks = assert_codegen_run(
+        model,
+        meta_args=meta_args,
+        data=data,
+        max_memory=max_memory,
+        print_code=print_code,
+        print_mem=print_mem,
+        print_est_mem=print_est_mem,
+        print_progress=print_progress,
+    )
+
+    if get_chunk_target is not None:
+        chunk_found = [i["region"] for i in chunks]
+        chunk_target = get_chunk_target()[max_memory]
+        assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+            str(chunk_found),
+            str(chunk_target),
+        )
+
+    gpc.destroy()
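
The `_add_dim` change at the top of the patch remaps negative dim indices before calling `list.insert`, because Python's `insert` interprets a negative position relative to the current length rather than the post-insert length. A minimal standalone sketch of that normalization (the function and variable names are illustrative, not taken from the patch):

def normalize_insert_pos(dim_idx: int, current_len: int) -> int:
    """Map a possibly negative dim index to a non-negative insert position."""
    if dim_idx < 0:
        # same trick as the patch: index into range(current_len + 1)
        dim_idx = list(range(current_len + 1))[dim_idx]
    return dim_idx


if __name__ == "__main__":
    indice = ["a", "b", "c"]
    pos = normalize_insert_pos(-1, len(indice))    # -> 3, i.e. append at the end
    indice.insert(pos, "new")
    assert indice == ["a", "b", "c", "new"]
    # a plain indice.insert(-1, "new") would give ["a", "b", "new", "c"] instead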
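A quick sanity check of the dimension arithmetic behind `_assign_flatten_indice`, in plain PyTorch: `flatten(x, start_dim)` collapses the trailing dims into one, so `len(shape) - start_dim - 1` dims disappear; the tracer pads that many dims, copies the input trace, then drops the trailing ones and appends a fresh dim for the flattened axis. The shapes below are arbitrary examples, not taken from the patch:

import torch

x = torch.randn(2, 3, 4, 5)
start_dim = 1
y = torch.flatten(x, start_dim)

flatten_dim_num = x.dim() - start_dim - 1      # number of dims that disappear: 2
assert flatten_dim_num > 0
assert y.dim() == x.dim() - flatten_dim_num    # 4 - 2 = 2
assert y.shape == (2, 3 * 4 * 5)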
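The rewritten view/reshape handler no longer assumes the two shapes differ by exactly one dim; it matches the unchanged leading and trailing dims and treats everything in between as the changed region. A simplified, self-contained version of that matching for reference (the function name, the sorted output, and the `min(...)` loop bounds are mine, added so the sketch also handles shapes of different lengths safely):

from typing import List, Tuple


def find_changed_dims(origin_shape: List[int], target_shape: List[int]) -> Tuple[List[int], List[int]]:
    """Return (dim_from, dim_to): the dims of origin/target that do not line up."""
    same_from, same_to = [], []
    # walk matching dims from the front
    for i in range(min(len(origin_shape), len(target_shape))):
        if origin_shape[i] == target_shape[i]:
            same_from.append(i)
            same_to.append(i)
        else:
            break
    # walk matching dims from the back
    for i in range(-1, -min(len(origin_shape), len(target_shape)), -1):
        if origin_shape[i] == target_shape[i]:
            same_from.append(len(origin_shape) + i)
            same_to.append(len(target_shape) + i)
        else:
            break
    dim_from = sorted(set(range(len(origin_shape))) - set(same_from))
    dim_to = sorted(set(range(len(target_shape))) - set(same_to))
    return dim_from, dim_to


if __name__ == "__main__":
    # merge two middle dims: (2, 3, 4, 5) -> (2, 12, 5)
    assert find_changed_dims([2, 3, 4, 5], [2, 12, 5]) == ([1, 2], [1])
    # expand one dim back: (2, 12, 5) -> (2, 3, 4, 5)
    assert find_changed_dims([2, 12, 5], [2, 3, 4, 5]) == ([1], [1, 2])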
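The codegen changes treat `unbind` exactly like `split` because both produce a tuple of tensors: `node.meta['tensor_meta']` is then a sequence, the generated loop prologue has to allocate one output buffer per element (hence `node.meta["tensor_meta"][0].shape` in `get_node_shape`), and each chunk slice must be written into every element. A small PyTorch illustration of that pattern (the buffer and loop names are illustrative, not generated code):

import torch

x = torch.randn(4, 6, 8)

split_out = torch.split(x, 2, dim=0)     # tuple of 2 tensors, each (2, 6, 8)
unbind_out = torch.unbind(x, dim=0)      # tuple of 4 tensors, each (6, 8)
assert isinstance(split_out, tuple) and isinstance(unbind_out, tuple)

# what the chunked loop effectively does for tuple outputs: one empty buffer per
# element, then every chunk iteration fills the same slice of each element
buffers = tuple(torch.empty_like(t) for t in unbind_out)
chunk_size = 2
for start in range(0, unbind_out[0].shape[0], chunk_size):
    sl = slice(start, start + chunk_size)
    for buf, src in zip(buffers, unbind_out):
        buf[sl] = src[sl]

for buf, src in zip(buffers, unbind_out):
    assert torch.equal(buf, src)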
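The benchmark helpers measure activation memory as the gap between the CUDA peak after a forward pass and the allocation level before it, and latency with a warmup loop plus explicit synchronization so asynchronous kernels are fully counted. The same pattern applied to an arbitrary module, as a sketch (assumes a CUDA device is available; `measure` is an illustrative name, not an API from the patch):

import time

import torch


def measure(model: torch.nn.Module, *inputs, loop: int = 5):
    model = model.cuda().eval()
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    with torch.no_grad():
        # peak activation memory of one forward pass
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated() / 1024**2
        model(*inputs)
        act_mem = torch.cuda.max_memory_allocated() / 1024**2 - base

        # average latency with warmup and synchronization
        for _ in range(loop // 2 + 1):
            model(*inputs)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(loop):
            model(*inputs)
        torch.cuda.synchronize()
        latency = (time.time() - start) / loop
    return act_mem, latency


if __name__ == "__main__":
    act_mem, latency = measure(torch.nn.Linear(1024, 1024), torch.randn(64, 1024))
    print("act mem: %.2fMB, time: %.4fs" % (act_mem, latency))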