Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-06 12:07:00 +00:00)
[fx/profiler] debug the fx.profiler / add an example test script for fx.profiler (#1730)
* [fx/profiler] add test.
* [fx] fix file names.
* [fx] add docstring and comment.
* [fx] polish profiler.py.
* [fx] fix import errors.
* [fx] fix profiler.
* [fx] fix names.
This commit is contained in:
parent eee84908d4
commit 30874f1692
@@ -1,6 +1,6 @@
 import torch
 
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
 
 aten = torch.ops.aten
 
@@ -30,3 +30,15 @@ INPLACE_MATH_ATEN = [
 CLONE_ATEN = [
     aten.clone.default,
 ]
+
+# See illustrations in
+# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py
+OUTPUT_SAVED_OPS = [
+    torch.nn.functional.relu,
+    torch.nn.functional.softmax,
+]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
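
Note (not part of the commit): the "output saved for backward" property that OUTPUT_SAVED_OPS / OUTPUT_SAVED_MOD encode can be checked with torch.autograd.graph.saved_tensors_hooks, which the new test below also references in comments. A minimal sketch, assuming a recent PyTorch; the tensor shape is illustrative:

import torch
from torch.autograd.graph import saved_tensors_hooks

saved = []

def pack(t):
    # record every tensor that autograd saves for the backward pass
    saved.append(t)
    return t

x = torch.randn(4, 8, requires_grad=True)
with saved_tensors_hooks(pack, lambda t: t):
    y = torch.nn.functional.softmax(x, dim=-1)

print(len(saved))    # 1: softmax saves only its output, not its input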
@@ -5,6 +5,9 @@ from torch.fx import GraphModule, Node
 
 from .._compatibility import compatibility, is_compatible_with_meta
 
+if is_compatible_with_meta():
+    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
+
 __all__ = [
     'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
 ]
@@ -71,14 +74,35 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """
 
-    def is_relu_node(n: Node) -> bool:
+    def is_relu_like_node(n: Node) -> bool:
+        """Check if a node is a ReLU-like node.
+
+        ReLU-like nodes have the following properties:
+        - They are either `call_function` or `call_module`
+        - Their output tensors are directly saved for backward
+        - Their input tensors are not saved for backward
+
+        An example is `torch.nn.functional.softmax` which has (forward + backward):
+        def forward(self, input_2):
+            _softmax_default = torch.ops.aten._softmax.default(input_2, None, None);  input_2 = None
+            zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
+            detach_default = torch.ops.aten.detach.default(_softmax_default);  _softmax_default = None
+            _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None);  zeros_like_default = detach_default = None
+            detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default);  _softmax_backward_data_default = None
+            detach_default_2 = torch.ops.aten.detach.default(detach_default_1);  detach_default_1 = None
+
+        Args:
+            n (Node): A node from the graph
+
+        Returns:
+            bool: Whether the node is a ReLU-like node
+        """
         if n.op == 'call_function':
-            return n.target in [torch.nn.functional.relu]
+            return n.target in OUTPUT_SAVED_OPS
         elif n.op == 'call_module':
-            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
         return False
 
-    if not is_relu_node(n):
+    if not is_relu_like_node(n):
         return activation_size(n.meta["fwd_tmp"])
     return 0
 
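Note (not part of the commit): a sketch of how the ReLU-like check above applies to the nodes of a traced graph. The module is illustrative, and the membership lists inline what the diff reads from OUTPUT_SAVED_OPS / OUTPUT_SAVED_MOD:

import torch
import torch.fx

class TinyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return torch.nn.functional.softmax(self.act(x), dim=-1)

gm = torch.fx.symbolic_trace(TinyNet())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        output_saved = node.target in [torch.nn.functional.relu, torch.nn.functional.softmax]
    elif node.op == 'call_module':
        output_saved = type(gm.get_submodule(node.target)) in [torch.nn.ReLU, torch.nn.Softmax]
    else:
        output_saved = False
    # ReLU-like nodes contribute no `fwd_tmp`, since only their output is kept for backward
    print(node.op, node.target, output_saved)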
@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
 from torch.utils._pytree import tree_map
 
 from .._compatibility import compatibility
-from .constants import ALIAS_ATEN
+from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
 from .memory import activation_size, parameter_size
 from .opcount import flop_mapping
@@ -272,6 +272,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
             tensor = x._tensor.detach()
             tensor.uuid = x._tensor.uuid
             return tensor
-        return x
+        if not isinstance(x, torch.finfo):
+            return x
 
     graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
@ -314,21 +315,17 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
|||||||
# If there is an argument that this `call_function` is inplace, we should
|
# If there is an argument that this `call_function` is inplace, we should
|
||||||
# still run the profiling but discard some results regarding `target`
|
# still run the profiling but discard some results regarding `target`
|
||||||
global do_not_cache
|
global do_not_cache
|
||||||
|
|
||||||
inplace = kwargs.get('inplace', False)
|
inplace = kwargs.get('inplace', False)
|
||||||
if inplace or target in [torch.nn.functional.relu]:
|
if target in OUTPUT_SAVED_OPS:
|
||||||
|
do_not_cache = True
|
||||||
|
if inplace:
|
||||||
do_not_cache = True
|
do_not_cache = True
|
||||||
kwargs['inplace'] = False
|
kwargs['inplace'] = False
|
||||||
if device == 'meta':
|
if device == 'meta':
|
||||||
out, meta = _profile_meta(func, *args, **kwargs)
|
out, meta = _profile_meta(func, *args, **kwargs)
|
||||||
# currently we set the fwd_mem_tmp of ReLU to zero
|
|
||||||
if target in [torch.nn.functional.relu]:
|
|
||||||
meta.fwd_in = []
|
|
||||||
meta.fwd_tmp = []
|
|
||||||
meta.bwd_mem_out = 0
|
|
||||||
meta.fwd_mem_tmp = 0
|
|
||||||
else:
|
else:
|
||||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||||
|
|
||||||
if inplace:
|
if inplace:
|
||||||
kwargs['inplace'] = True
|
kwargs['inplace'] = True
|
||||||
do_not_cache = False
|
do_not_cache = False
|
||||||
@ -386,20 +383,16 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
|
|||||||
global do_not_cache
|
global do_not_cache
|
||||||
|
|
||||||
inplace = getattr(module, 'inplace', False)
|
inplace = getattr(module, 'inplace', False)
|
||||||
if inplace or type(module) in [torch.nn.ReLU]:
|
if type(module) in OUTPUT_SAVED_MOD:
|
||||||
|
do_not_cache = True
|
||||||
|
if inplace:
|
||||||
do_not_cache = True
|
do_not_cache = True
|
||||||
module.inplace = False
|
module.inplace = False
|
||||||
if device == 'meta':
|
if device == 'meta':
|
||||||
out, meta = _profile_meta(func, *args, **kwargs)
|
out, meta = _profile_meta(func, *args, **kwargs)
|
||||||
# currently we set the fwd_tmp of ReLU to []
|
|
||||||
if type(module) in [torch.nn.ReLU]:
|
|
||||||
meta.fwd_in = []
|
|
||||||
meta.fwd_tmp = []
|
|
||||||
meta.bwd_mem_out = 0
|
|
||||||
else:
|
else:
|
||||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||||
if inplace:
|
if inplace:
|
||||||
|
|
||||||
module.inplace = True
|
module.inplace = True
|
||||||
do_not_cache = False
|
do_not_cache = False
|
||||||
|
|
||||||
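Note (not part of the commit): both wrappers above temporarily disable `inplace` so the profiled call does not mutate its input, then restore the caller's flag. A standalone sketch of that pattern; the function name is hypothetical:

import torch

def profile_with_inplace_disabled(func, *args, **kwargs):
    inplace = kwargs.get('inplace', False)
    if inplace:
        kwargs['inplace'] = False    # measure the out-of-place variant
    out = func(*args, **kwargs)
    if inplace:
        kwargs['inplace'] = True     # restore the caller's flag afterwards
    return out

x = torch.randn(8)
y = profile_with_inplace_disabled(torch.nn.functional.relu, x, inplace=True)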
@@ -125,5 +125,5 @@ class MetaTensor(torch.Tensor):
             device = kwargs['device']
         result = super().to(*args, **kwargs)
         if device is not None:
-            result = MetaTensor(deepcopy(result), fake_device=device)
+            result = MetaTensor(result, fake_device=device)
         return result
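Note (not part of the commit): with this change, `.to()` re-wraps the converted tensor instead of deep-copying it. A hedged usage sketch, assuming a PyTorch version where MetaTensor is available; the device names are illustrative:

import torch
from colossalai.fx.profiler import MetaTensor

x = MetaTensor(torch.rand(2, 3, device='meta'), fake_device='cuda:0')
y = x.to(device='cuda:1')    # re-wrapped as MetaTensor(..., fake_device='cuda:1'), no deepcopy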
tests/test_fx/test_profiler/gpt_utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel


class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=False):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
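Note (not part of the commit): a minimal usage sketch of these helpers; the batch size and sequence length are illustrative, and `transformers` must be installed:

import torch
from gpt_utils import GPTLMLoss, gpt2_medium

model = gpt2_medium(checkpoint=False)
criterion = GPTLMLoss()
input_ids = torch.randint(0, 50257, (2, 128))
attention_mask = torch.ones_like(input_ids)
logits = model(input_ids, attention_mask)
loss = criterion(logits, input_ids)    # causal LM: labels are the input ids, shifted inside the loss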
tests/test_fx/test_profiler/test_profiler_meta_info_prop.py (new file, 181 lines)
@@ -0,0 +1,181 @@
from typing import Optional, Tuple, Union

import torch
import torch.fx
import torchvision.models as tm
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from gpt_utils import gpt2_medium, gpt2_xl
from torch.fx import symbolic_trace

if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor

TM_BATCH_SIZE = 64
GPT_BATCH_SIZE = 8
NUM_STEPS = 5


def extract_forward_mem(gm: torch.fx.GraphModule):
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += calculate_fwd_tmp(node)
        node_size += calculate_fwd_out(node)
    param_size = parameter_size(gm)
    return (node_size + param_size) / 1024**2, param_size / 1024**2


def extract_forward_flops(gm: torch.fx.GraphModule):
    fwd_flop = 0
    bwd_flop = 0
    for node in gm.graph.nodes:
        fwd_flop += node.meta.get('fwd_flop', 0)
        bwd_flop += node.meta.get('bwd_flop', 0)
    return fwd_flop, bwd_flop


def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return data, label


def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    return input_ids, attention_mask


def run_tm_forward(gm: torch.fx.GraphModule):
    torch.cuda.reset_peak_memory_stats()
    forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    gm.cuda()
    param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    gm.train()
    for n in range(NUM_STEPS):
        torch.cuda.reset_peak_memory_stats()
        data, _ = gen_tm_data(TM_BATCH_SIZE, (3, 224, 224))

        # If we need to dive deep into the memory usage by
        # inspecting `saved_tensor_hooks`

        # =====================================================
        # fwd_mem = 0
        # cache = set()
        # def pack(x):
        #     if isinstance(x, torch.Tensor):
        #         nonlocal fwd_mem, cache
        #         if x.data_ptr() not in cache:
        #             fwd_mem += activation_size(x)
        #             cache.add(x.data_ptr())
        #     return x
        # def unpack(x):
        #     return x
        #
        # with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        #     output = gm(data)
        # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}')
        # =====================================================

        output = gm(data)
        forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS
        del output
    return forward_mem, param_mem


def run_gpt_forward(gm: torch.fx.GraphModule):
    torch.cuda.reset_peak_memory_stats()
    forward_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    param_mem = -torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    gm.cuda()
    param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2
    for n in range(NUM_STEPS):
        torch.cuda.reset_peak_memory_stats()
        data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0')

        # If we need to dive deep into the memory usage by
        # inspecting `saved_tensor_hooks`

        # =====================================================
        # fwd_mem = 0
        # cache = set()
        # def pack(x):
        #     if isinstance(x, torch.Tensor):
        #         nonlocal fwd_mem, cache
        #         if x.data_ptr() not in cache:
        #             fwd_mem += activation_size(x)
        #             cache.add(x.data_ptr())
        #     return x
        # def unpack(x):
        #     return x
        #
        # with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        #     output = gm(data, mask)
        # print(f'Memory estimation by saved_tensor_hooks: {fwd_mem / 1024**2}')
        # =====================================================

        output = gm(data, mask)
        forward_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 / NUM_STEPS
        del output
    return forward_mem, param_mem


@run_on_environment_flag(name='FX_PROFILER')
def test_meta_info_prop():
    for m in [
            tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121,
            tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base,
            tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5,
            tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5,
            tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d,
            tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32,
            tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn
    ]:
        model = m().cuda()
        model.train()
        data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0')
        gm = symbolic_trace(model)
        interp = MetaInfoProp(gm)
        interp.propagate(data)
        gm.cpu()

        meta_forward_mem, meta_param_mem = extract_forward_mem(gm)
        fwd_flop, bwd_flop = extract_forward_flops(gm)
        concrete_forward_mem, concrete_param_mem = run_tm_forward(gm)

        print(
            f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
        )
        del model, gm


@run_on_environment_flag(name='FX_PROFILER')
def test_gpt_meta_info_prop():
    for m in [gpt2_medium]:
        model = m().cuda()
        model.train()
        data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta')
        graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
        gm = torch.fx.GraphModule(model, graph)
        interp = MetaInfoProp(gm)
        interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0'))
        model.cpu()

        fwd_flop, bwd_flop = extract_forward_flops(gm)

        concrete_forward_mem, concrete_param_mem = run_gpt_forward(gm)
        meta_forward_mem, meta_param_mem = extract_forward_mem(gm)

        print(
            f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|'
        )
        del model, gm


if __name__ == '__main__':
    test_meta_info_prop()
    test_gpt_meta_info_prop()
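Note (not part of the commit): the tests are gated behind the FX_PROFILER environment flag. Distilled from them, a minimal end-to-end sketch of the profiling flow; it assumes a CUDA device and a PyTorch version where `is_compatible_with_meta()` holds, and resnet18 plus the batch size are illustrative:

import torch
import torchvision.models as tm
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor, calculate_fwd_out, calculate_fwd_tmp

model = tm.resnet18().cuda().train()
gm = symbolic_trace(model)
data = MetaTensor(torch.rand(8, 3, 224, 224, device='meta'), fake_device='cuda:0')
MetaInfoProp(gm).propagate(data)
for node in gm.graph.nodes:
    # per-node estimates written into `node.meta` by the propagation pass
    print(node.name, calculate_fwd_tmp(node), calculate_fwd_out(node), node.meta.get('fwd_flop', 0))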