diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py
index 38214e219..5763a46dc 100644
--- a/colossalai/fx/profiler/constants.py
+++ b/colossalai/fx/profiler/constants.py
@@ -1,6 +1,6 @@
 import torch
 
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'OUTPUT_SAVED_OPS', 'OUTPUT_SAVED_MOD']
 
 aten = torch.ops.aten
 
@@ -30,3 +30,15 @@ INPLACE_MATH_ATEN = [
 CLONE_ATEN = [
     aten.clone.default,
 ]
+
+# See illustrations in
+# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py
+OUTPUT_SAVED_OPS = [
+    torch.nn.functional.relu,
+    torch.nn.functional.softmax,
+]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
diff --git a/colossalai/fx/profiler/memory.py b/colossalai/fx/profiler/memory.py
index 1a3f127f1..2e8b5d51b 100644
--- a/colossalai/fx/profiler/memory.py
+++ b/colossalai/fx/profiler/memory.py
@@ -5,6 +5,9 @@ from torch.fx import GraphModule, Node
 
 from .._compatibility import compatibility, is_compatible_with_meta
 
+if is_compatible_with_meta():
+    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
+
 __all__ = [
     'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
 ]
@@ -71,14 +74,35 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """
 
-    def is_relu_node(n: Node) -> bool:
+    def is_relu_like_node(n: Node) -> bool:
+        """Check if a node is a ReLU-like node.
+        ReLU-like nodes have the following properties:
+        - They are either `call_function` or `call_module`
+        - Their output tensors are directly saved for backward
+        - Their input tensors are not saved for backward
+
+        An example is `torch.nn.functional.softmax` which has (forward + backward):
+        def forward(self, input_2):
+            _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
+            zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
+            detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
+            _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
+            detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
+            detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None
+
+        Args:
+            n (Node): A node from the graph
+
+        Returns:
+            bool: Whether the node is a ReLU-like node
+        """
         if n.op == 'call_function':
-            return n.target in [torch.nn.functional.relu]
+            return n.target in OUTPUT_SAVED_OPS
         elif n.op == 'call_module':
-            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
         return False
 
-    if not is_relu_node(n):
+    if not is_relu_like_node(n):
         return activation_size(n.meta["fwd_tmp"])
     return 0
 
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index 608cc9e4d..2fa5c41c0 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
 from torch.utils._pytree import tree_map
 
 from .._compatibility import compatibility
-from .constants import ALIAS_ATEN
+from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
 from .memory import activation_size, parameter_size
 from .opcount import flop_mapping
@@ -272,7 +272,8 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
             tensor = x._tensor.detach()
             tensor.uuid = x._tensor.uuid
             return tensor
-        return x
+        if not isinstance(x, torch.finfo):
+            return x
 
     graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
 
@@ -314,21 +315,17 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
         # If there is an argument that this `call_function` is inplace, we should
         # still run the profiling but discard some results regarding `target`
         global do_not_cache
+
         inplace = kwargs.get('inplace', False)
-        if inplace or target in [torch.nn.functional.relu]:
+        if target in OUTPUT_SAVED_OPS:
+            do_not_cache = True
+        if inplace:
             do_not_cache = True
             kwargs['inplace'] = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.fwd_in = []
-                meta.fwd_tmp = []
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
-
         if inplace:
             kwargs['inplace'] = True
             do_not_cache = False
@@ -386,20 +383,16 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
         global do_not_cache
 
         inplace = getattr(module, 'inplace', False)
-        if inplace or type(module) in [torch.nn.ReLU]:
+        if type(module) in OUTPUT_SAVED_MOD:
+            do_not_cache = True
+        if inplace:
             do_not_cache = True
             module.inplace = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_tmp of ReLU to []
-            if type(module) in [torch.nn.ReLU]:
-                meta.fwd_in = []
-                meta.fwd_tmp = []
-                meta.bwd_mem_out = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
-
             module.inplace = True
             do_not_cache = False
 
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index d4a078c2a..3be3dd65c 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -125,5 +125,5 @@ class MetaTensor(torch.Tensor):
             device = kwargs['device']
         result = super().to(*args, **kwargs)
         if device is not None:
-            result = MetaTensor(deepcopy(result), fake_device=device)
+            result = MetaTensor(result, fake_device=device)
         return result
diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py
new file mode 100644
index 000000000..aec322684
--- /dev/null
+++ b/tests/test_fx/test_profiler/gpt_utils.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+from transformers import GPT2Config, GPT2LMHeadModel
+
+
+class GPTLMModel(nn.Module):
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50257,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+def gpt2_medium(checkpoint=False):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_xl(checkpoint=False):
+    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py
new file mode 100644
index 000000000..a9921af3c
--- /dev/null
+++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py
@@ -0,0 +1,181 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.fx
+import torchvision.models as tm
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from gpt_utils import gpt2_medium, gpt2_xl
+from torch.fx import symbolic_trace
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
+
+TM_BATCH_SIZE = 64
+GPT_BATCH_SIZE = 8
+NUM_STEPS = 5
+
+
+def extract_forward_mem(gm: torch.fx.GraphModule):
+    node_size = 0
+    param_size = 0
+    for node in gm.graph.nodes:
+        node_size += calculate_fwd_tmp(node)
+        node_size += calculate_fwd_out(node)
+    param_size = parameter_size(gm)
+    return (node_size + param_size) / 1024**2, param_size / 1024**2
+
+
+def extract_forward_flops(gm: torch.fx.GraphModule):
+    fwd_flop = 0
+    bwd_flop = 0
+    for node in gm.graph.nodes:
+        fwd_flop += node.meta.get('fwd_flop', 0)
+        bwd_flop += node.meta.get('bwd_flop', 0)
+    return fwd_flop, bwd_flop
+
+
+def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
+    data = torch.rand(batch_size, *shape, device=device)
+    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
+    return data, label
+
+
+def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'):
+    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+    attention_mask = torch.ones_like(input_ids, device=device)
+    return input_ids, attention_mask
+
+
+def run_tm_forward(gm: torch.fx.GraphModule):
+    torch.cuda.reset_peak_memory_stats()
+    forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    gm.cuda()
+    param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    gm.train()
+    for n in range(NUM_STEPS):
+        torch.cuda.reset_peak_memory_stats()
+        data, _ = gen_tm_data(TM_BATCH_SIZE, (3, 224, 224))
+
+        # If we need to dive deep into the memory usage by
+        # inspecting `saved_tensor_hooks`
+
+        # =====================================================
+        # fwd_mem = 0
+        # cache = set()
+        # def pack(x):
+        #     if isinstance(x, torch.Tensor):
+        #         nonlocal fwd_mem, cache
+        #         if x.data_ptr() not in cache:
+        #             fwd_mem += activation_size(x)
+        #             cache.add(x.data_ptr())
+        #     return x
+        # def unpack(x):
+        #     return x
+        #
+        # with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        #     output = gm(data)
+        # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}')
+        # =====================================================
+
+        output = gm(data)
+        forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS
+        del output
+    return forward_mem, param_mem
+
+
+def run_gpt_forward(gm: torch.fx.GraphModule):
+    torch.cuda.reset_peak_memory_stats()
+    forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    gm.cuda()
+    param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2
+    for n in range(NUM_STEPS):
+        torch.cuda.reset_peak_memory_stats()
+        data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0')
+
+        # If we need to dive deep into the memory usage by
+        # inspecting `saved_tensor_hooks`
+
+        # =====================================================
+        # fwd_mem = 0
+        # cache = set()
+        # def pack(x):
+        #     if isinstance(x, torch.Tensor):
+        #         nonlocal fwd_mem, cache
+        #         if x.data_ptr() not in cache:
+        #             fwd_mem += activation_size(x)
+        #             cache.add(x.data_ptr())
+        #     return x
+        # def unpack(x):
+        #     return x
+        #
+        # with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        #     output = gm(data, mask)
+        # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}')
+        # =====================================================
+
+        output = gm(data, mask)
+        forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS
+        del output
+    return forward_mem, param_mem
+
+
+@run_on_environment_flag(name='FX_PROFILER')
+def test_meta_info_prop():
+    for m in [
+            tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121,
+            tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base,
+            tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5,
+            tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5,
+            tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d,
+            tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32,
+            tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn
+    ]:
+        model = m().cuda()
+        model.train()
+        data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0')
+        gm = symbolic_trace(model)
+        interp = MetaInfoProp(gm)
+        interp.propagate(data)
+        gm.cpu()
+
+        meta_forward_mem, meta_param_mem = extract_forward_mem(gm)
+        fwd_flop, bwd_flop = extract_forward_flops(gm)
+        concrete_forward_mem, concrete_param_mem = run_tm_forward(gm)
+
+        print(
+            f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
+        )
+        del model, gm
+
+
+@run_on_environment_flag(name='FX_PROFILER')
+def test_gpt_meta_info_prop():
+    for m in [gpt2_medium]:
+        model = m().cuda()
+        model.train()
+        data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta')
+        graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
+        gm = torch.fx.GraphModule(model, graph)
+        interp = MetaInfoProp(gm)
+        interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0'))
+        model.cpu()
+
+        fwd_flop, bwd_flop = extract_forward_flops(gm)
+
+        concrete_forward_mem, concrete_param_mem = run_gpt_forward(gm)
+        meta_forward_mem, meta_param_mem = extract_forward_mem(gm)
+
+        print(
+            f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
+        )
+        del model, gm
+
+
+if __name__ == '__main__':
+    test_meta_info_prop()
+    test_gpt_meta_info_prop()
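
Note: the standalone sketch below is not part of the patch. It only illustrates why ReLU/Softmax are grouped into OUTPUT_SAVED_OPS / OUTPUT_SAVED_MOD (autograd saves their output, not their input, for backward, so calculate_fwd_tmp reports 0 and the memory shows up in fwd_out instead). It reuses the same torch.autograd.graph.saved_tensors_hooks trick that is left commented out in run_tm_forward, and assumes PyTorch 1.10+.

# Illustrative sketch only; checks that relu saves its *output* for backward.
import torch

x = torch.randn(8, requires_grad=True)
saved = []

def pack(t):
    # Called whenever autograd saves a tensor for backward.
    saved.append(t)
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.nn.functional.relu(x)

# The saved tensor shares storage with the output `y`, not the input `x`,
# which is the property the profiler's OUTPUT_SAVED_* constants encode.
print(any(t.data_ptr() == y.data_ptr() for t in saved))    # expected: True
print(any(t.data_ptr() == x.data_ptr() for t in saved))    # expected: False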