mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,26 +1,29 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257,
|
||||
checkpoint=False):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257,
|
||||
checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
GPT2Config(
|
||||
n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
)
|
||||
if checkpoint:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
@@ -30,7 +33,6 @@ class GPTLMModel(nn.Module):
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
@@ -1,9 +1,9 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torchvision.models as tm
|
||||
from gpt_utils import gpt2_medium, gpt2_xl
|
||||
from gpt_utils import gpt2_medium
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
@@ -33,18 +33,18 @@ def extract_forward_flops(gm: torch.fx.GraphModule):
|
||||
fwd_flop = 0
|
||||
bwd_flop = 0
|
||||
for node in gm.graph.nodes:
|
||||
fwd_flop += node.meta.get('fwd_flop', 0)
|
||||
bwd_flop += node.meta.get('bwd_flop', 0)
|
||||
fwd_flop += node.meta.get("fwd_flop", 0)
|
||||
bwd_flop += node.meta.get("bwd_flop", 0)
|
||||
return fwd_flop, bwd_flop
|
||||
|
||||
|
||||
def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
|
||||
def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device="cuda"):
|
||||
data = torch.rand(batch_size, *shape, device=device)
|
||||
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
|
||||
return data, label
|
||||
|
||||
|
||||
def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'):
|
||||
def gen_gpt_data(batch_size, seq_len, vocab_size, device="cpu"):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
|
||||
attention_mask = torch.ones_like(input_ids, device=device)
|
||||
return input_ids, attention_mask
|
||||
@@ -96,7 +96,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule):
|
||||
param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2
|
||||
for n in range(NUM_STEPS):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0')
|
||||
data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="cuda:0")
|
||||
|
||||
# If we need to dive deep into the memory usage by
|
||||
# inspecting `saved_tensor_hooks`
|
||||
@@ -125,21 +125,56 @@ def run_gpt_forward(gm: torch.fx.GraphModule):
|
||||
return forward_mem, param_mem
|
||||
|
||||
|
||||
@run_on_environment_flag(name='FX_PROFILER')
|
||||
@run_on_environment_flag(name="FX_PROFILER")
|
||||
@clear_cache_before_run()
|
||||
def test_meta_info_prop():
|
||||
for m in [
|
||||
tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121,
|
||||
tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base,
|
||||
tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5,
|
||||
tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5,
|
||||
tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d,
|
||||
tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32,
|
||||
tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn
|
||||
tm.alexnet,
|
||||
tm.resnet18,
|
||||
tm.resnet34,
|
||||
tm.resnet50,
|
||||
tm.resnet101,
|
||||
tm.resnet152,
|
||||
tm.densenet121,
|
||||
tm.densenet161,
|
||||
tm.densenet169,
|
||||
tm.densenet201,
|
||||
tm.convnext_tiny,
|
||||
tm.convnext_small,
|
||||
tm.convnext_base,
|
||||
tm.convnext_large,
|
||||
tm.wide_resnet50_2,
|
||||
tm.wide_resnet101_2,
|
||||
tm.regnet_x_16gf,
|
||||
tm.mnasnet0_5,
|
||||
tm.efficientnet_b0,
|
||||
tm.shufflenet_v2_x0_5,
|
||||
tm.shufflenet_v2_x1_0,
|
||||
tm.shufflenet_v2_x1_5,
|
||||
tm.shufflenet_v2_x2_0,
|
||||
tm.mobilenet_v2,
|
||||
tm.mobilenet_v3_small,
|
||||
tm.mobilenet_v3_large,
|
||||
tm.resnext50_32x4d,
|
||||
tm.resnext101_32x8d,
|
||||
tm.resnext101_64x4d,
|
||||
tm.vit_b_16,
|
||||
tm.vit_b_32,
|
||||
tm.vit_h_14,
|
||||
tm.vit_l_16,
|
||||
tm.vit_l_32,
|
||||
tm.vgg11,
|
||||
tm.vgg11_bn,
|
||||
tm.vgg13,
|
||||
tm.vgg13_bn,
|
||||
tm.vgg16,
|
||||
tm.vgg16_bn,
|
||||
tm.vgg19,
|
||||
tm.vgg19_bn,
|
||||
]:
|
||||
model = m().cuda()
|
||||
model.train()
|
||||
data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0')
|
||||
data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device="meta"), fake_device="cuda:0")
|
||||
gm = symbolic_trace(model)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(data)
|
||||
@@ -150,22 +185,22 @@ def test_meta_info_prop():
|
||||
concrete_forward_mem, concrete_param_mem = run_tm_forward(gm)
|
||||
|
||||
print(
|
||||
f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
|
||||
f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|"
|
||||
)
|
||||
del model, gm
|
||||
|
||||
|
||||
@run_on_environment_flag(name='FX_PROFILER')
|
||||
@run_on_environment_flag(name="FX_PROFILER")
|
||||
@clear_cache_before_run()
|
||||
def test_gpt_meta_info_prop():
|
||||
for m in [gpt2_medium]:
|
||||
model = m().cuda()
|
||||
model.train()
|
||||
data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta')
|
||||
graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
|
||||
data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="meta")
|
||||
graph = ColoTracer().trace(model, meta_args={"input_ids": data, "attention_mask": mask})
|
||||
gm = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0'))
|
||||
interp.propagate(MetaTensor(data, fake_device="cuda:0"), MetaTensor(mask, fake_device="cuda:0"))
|
||||
model.cpu()
|
||||
|
||||
fwd_flop, bwd_flop = extract_forward_flops(gm)
|
||||
@@ -174,11 +209,11 @@ def test_gpt_meta_info_prop():
|
||||
meta_forward_mem, meta_param_mem = extract_forward_mem(gm)
|
||||
|
||||
print(
|
||||
f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
|
||||
f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|"
|
||||
)
|
||||
del model, gm
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_meta_info_prop()
|
||||
test_gpt_meta_info_prop()
|
||||
|
Reference in New Issue
Block a user