Mirror of https://github.com/hpcaitech/ColossalAI.git
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.
* [siu] add experimental submodules to main branch.
* [siu]
* [siu]
* [analyzer] init.
* [analyzer] readme.
* [analyzer] readme.
* [analyzer] readme.
* [analyzer] readme.
* [test] add test.
* Update symbolic_trace.py
* mark skip tests.
* try except.
* try except.
* try except.
* s
* init
* init
* fix
* skip
* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
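Taken together, the new tests exercise the analyzer end to end: tracing a static graph on meta tensors, propagating shapes, symbolic profiling, and FLOP counting. The following is a minimal usage sketch of that API, assuming `colossalai._analyzer` imports successfully and torch >= 1.12 is installed; the resnet18 model, the (8, 3, 224, 224) input, and combining all passes in one script are illustrative choices, not part of this commit.

import torch
import torchvision.models as tm

from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
from colossalai._analyzer.fx import symbolic_profile, symbolic_trace
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass

with MetaTensorMode():    # instantiate parameters on the meta device to save memory
    model = tm.resnet18()
data = torch.rand(8, 3, 224, 224)

# trace the static graph against meta inputs; bias_addition_split optionally
# splits fused bias additions into explicit '+' nodes in the generated code
gm = symbolic_trace(model, meta_args={'x': data}, bias_addition_split=False)

# propagate shapes and collect per-node information into node.meta['info']
shape_prop_pass(gm, data)
symbolic_profile(gm, data, verbose=False)

# estimate forward/backward FLOPs without running on real data
fwd_flops, bwd_flops = flop_count(model, data, verbose=False)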

tests/test_analyzer/test_fx/__init__.py (new file, 0 lines)

tests/test_analyzer/test_fx/test_bias_addition.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import pytest
import torch
from torch.utils.checkpoint import checkpoint

try:
    from colossalai._analyzer.fx import symbolic_trace
except:
    pass


class LinearModel(torch.nn.Module):

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvModel(torch.nn.Module):

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channel,
                                    out_channels,
                                    kernel_size,
                                    bias=bias,
                                    padding=1,
                                    stride=2,
                                    dilation=2,
                                    groups=3)
        self.conv_transpose = torch.nn.ConvTranspose2d(in_channel,
                                                       out_channels,
                                                       kernel_size,
                                                       bias=bias,
                                                       padding=1,
                                                       stride=2,
                                                       dilation=2,
                                                       groups=3)

    def forward(self, x, select=0):
        if select == 0:
            x = self.conv(x)
        else:
            x = self.conv_transpose(x)
        return x


class SiuModel(torch.nn.Module):

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    def forward(self, x, select=0):
        x = self.linear(x)
        x = checkpoint(self.conv, x, select)
        return x


class AddmmModel(torch.nn.Module):

    def __init__(self, alpha, beta) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
        return x


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@pytest.mark.parametrize("select", [0, 1])
def test_siu_model(bias, bias_addition_split, shape, select):
    model = SiuModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(model,
                        meta_args={'x': x},
                        concrete_args={'select': select},
                        trace_act_ckpt=True,
                        bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x, select), gm(x, select)), 'original model and traced model should be the same!'
    if bias and bias_addition_split:
        assert '+' in gm.code, 'bias addition should be split!'
    else:
        assert '+' not in gm.code, 'bias addition should not be split!'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("alpha", [1, 2])
@pytest.mark.parametrize("beta", [1, 2])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3), (5, 5)])
def test_addmm_model(alpha, beta, bias_addition_split, shape):
    model = AddmmModel(alpha=alpha, beta=beta)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
    assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
    if (alpha == 1 and beta == 1) or not bias_addition_split:
        assert '*' not in gm.code, 'bias addition should not be split!'
    elif bias_addition_split:
        assert '+' in gm.code, 'bias addition should be split!'


if __name__ == '__main__':
    test_siu_model(True, True, (3, 3, 3), 0)

tests/test_analyzer/test_fx/test_mod_dir.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import pytest
import torch

try:
    from colossalai._analyzer.fx import symbolic_trace
except:
    pass


class LinearModel(torch.nn.Module):

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvModel(torch.nn.Module):

    def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channel,
                                    out_channels,
                                    kernel_size,
                                    bias=bias,
                                    padding=1,
                                    stride=2,
                                    dilation=2,
                                    groups=3)
        self.conv_transpose = torch.nn.ConvTranspose2d(out_channels,
                                                       out_channels,
                                                       kernel_size,
                                                       bias=bias,
                                                       padding=1,
                                                       stride=2,
                                                       dilation=2,
                                                       groups=3)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_transpose(x)
        return x


class AModel(torch.nn.Module):

    def __init__(self, bias) -> None:
        super().__init__()
        self.linear_1 = LinearModel(3, 3, bias)
        self.linear_2 = LinearModel(3, 3, bias)
        self.conv = ConvModel(3, 6, 3, bias)

    def forward(self, x):
        for i in range(x.shape[0]):
            x = self.linear_1(x)
            x = self.linear_2(x)
        x = self.conv(x)
        return x


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
    model = AModel(bias=bias)
    x = torch.rand(shape)
    gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split)
    for node in gm.graph.nodes:
        assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``."
        print(node, node.meta['info'].mod_dir)


if __name__ == '__main__':
    test_mod_dir(True, True, (3, 3, 3))

tests/test_analyzer/test_fx/test_nested_ckpt.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import pytest

try:
    from colossalai._analyzer.fx import symbolic_trace
except:
    pass


class MyModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10)
        self.b = nn.Linear(10, 10)
        self.c = nn.Linear(10, 10)
        self.d = nn.Linear(10, 10)
        self.e = nn.Linear(10, 10)

    def checkpoint_0(self, x):
        return checkpoint(self.checkpoint_0_0, x) + checkpoint(self.checkpoint_0_1, x) + self.e(x)

    def checkpoint_0_0(self, x):
        return checkpoint(self.checkpoint_0_0_0, x) + checkpoint(self.checkpoint_0_0_1, x)

    def checkpoint_0_0_0(self, x):
        return self.a(x) + checkpoint(self.checkpoint_0_0_0_0, x, use_reentrant=False)

    def checkpoint_0_0_0_0(self, x):
        return self.b(x)

    def checkpoint_0_0_1(self, x):
        return self.b(x) + self.c(x)

    def checkpoint_0_1(self, x):
        return self.d(x)

    def forward(self, x):
        return checkpoint(self.checkpoint_0, x)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
def test_nested_ckpt():
    model = MyModule()
    x = torch.rand(10, 10)
    gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True)
    assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model."
    for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)):
        assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}"


if __name__ == "__main__":
    test_nested_ckpt()

tests/test_analyzer/test_fx/test_shape_prop.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models

try:
    from colossalai._analyzer._subclasses import MetaTensorMode
    from colossalai._analyzer.fx import symbolic_trace
    from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
    from colossalai._analyzer.fx.symbolic_profile import register_shape_impl


    @register_shape_impl(torch.nn.functional.linear)
    def linear_impl(*args, **kwargs):
        assert True
        return torch.nn.functional.linear(*args, **kwargs)
except:
    pass


def _check_gm_validity(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
        if node.op in [
                # 'call_module',    # can apply to params
                # 'call_function',    # can apply to params
                # 'call_method',    # can apply to params
        ]:
            assert node.meta['info'].inputs, f'In {gm.__class__.__name__}, {node} has no input shape.'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
def test_torchvision_shape_prop(m):
    with MetaTensorMode():
        model = m()
        data = torch.rand(100, 3, 224, 224)
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args)
    shape_prop_pass(gm, data)
    _check_gm_validity(gm)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
def test_timm_shape_prop(m):
    with MetaTensorMode():
        model = m()
        data = torch.rand(100, 3, 224, 224)
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args)
    shape_prop_pass(gm, data)
    _check_gm_validity(gm)


if __name__ == "__main__":
    test_torchvision_shape_prop(tm.resnet18)
    test_timm_shape_prop(tmm.vgg11)

tests/test_analyzer/test_fx/test_symbolic_profile.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models

try:
    from colossalai._analyzer._subclasses import MetaTensorMode
    from colossalai._analyzer.fx import symbolic_profile, symbolic_trace
except:
    pass


def _check_gm_validity(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
    with MetaTensorMode():
        model = m()
        data = torch.rand(8, 3, 224, 224)
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
    symbolic_profile(gm, data, verbose=verbose)
    _check_gm_validity(gm)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
    with MetaTensorMode():
        model = m()
        data = torch.rand(8, 3, 224, 224)
    meta_args = {
        "x": data,
    }
    gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
    symbolic_profile(gm, data, verbose=verbose)
    _check_gm_validity(gm)


if __name__ == "__main__":
    test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
    test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)

tests/test_analyzer/test_fx/zoo.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import timm.models as tmm
import torchvision.models as tm

# input shape: (batch_size, 3, 224, 224)
tm_models = [
    tm.alexnet,
    tm.convnext_base,
    tm.densenet121,
    # tm.efficientnet_v2_s,
    # tm.googlenet,    # output bad case
    # tm.inception_v3,    # bad case
    tm.mobilenet_v2,
    tm.mobilenet_v3_small,
    tm.mnasnet0_5,
    tm.resnet18,
    tm.regnet_x_16gf,
    tm.resnext50_32x4d,
    tm.shufflenet_v2_x0_5,
    tm.squeezenet1_0,
    # tm.swin_s,    # fx bad case
    tm.vgg11,
    tm.vit_b_16,
    tm.wide_resnet50_2,
]

tmm_models = [
    tmm.beit_base_patch16_224,
    tmm.beitv2_base_patch16_224,
    tmm.cait_s24_224,
    tmm.coat_lite_mini,
    tmm.convit_base,
    tmm.deit3_base_patch16_224,
    tmm.dm_nfnet_f0,
    tmm.eca_nfnet_l0,
    tmm.efficientformer_l1,
    tmm.ese_vovnet19b_dw,
    tmm.gmixer_12_224,
    tmm.gmlp_b16_224,
    tmm.hardcorenas_a,
    tmm.hrnet_w18_small,
    tmm.inception_v3,
    tmm.mixer_b16_224,
    tmm.nf_ecaresnet101,
    tmm.nf_regnet_b0,
    # tmm.pit_b_224,    # pretrained only
    tmm.regnetv_040,
    tmm.skresnet18,
    # tmm.swin_base_patch4_window7_224,    # fx bad case
    # tmm.tnt_b_patch16_224,    # bad case
    tmm.vgg11,
    tmm.vit_base_patch16_18x2_224,
    tmm.wide_resnet50_2,
]

tests/test_analyzer/test_subclasses/__init__.py (new file, 0 lines)

tests/test_analyzer/test_subclasses/test_aten.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import Any, Callable, Union
import pytest

import torch
import torch.nn as nn

try:
    from colossalai._analyzer._subclasses import MetaTensor
except:
    pass

aten = torch.ops.aten

registered_meta = {
    ('aten.convolution.default', True): [    # (aten ops, requires_backward)
        (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
        (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
        (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
        (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
        (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
                            dilation=2), torch.rand(2, 3, 4, 4)),
        (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
                            dilation=2), torch.rand(2, 3, 4, 4, 4)),
    ],
    ('aten.native_batch_norm.default', True): [
        (nn.BatchNorm1d(4), torch.rand(2, 4)),
        (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
        (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
    ],
    ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
    ('aten.avg_pool1d.default', True): [
        (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
        (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
        (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
        (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
    ],
    ('aten.avg_pool2d.default', True): [
        (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
        (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
        (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
        (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
    ],
    ('aten.relu.default', True): [
        (nn.ReLU(), torch.rand(4, 3, 1, 2)),
        (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
        (nn.SiLU(), torch.rand(4, 3, 1, 2)),
        (nn.GELU(), torch.rand(4, 3, 1, 2)),
        (nn.ELU(), torch.rand(4, 3, 1, 2)),
        (nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
        (nn.Tanh(), torch.rand(4, 3, 1, 2)),
        (nn.Hardswish(), torch.rand(4, 3, 1, 2)),
    ]
}


def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
    assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
    assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
    assert tensor.stride() == meta_tensor.stride(
    ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'


def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
    x.requires_grad = requires_backward
    meta_x = MetaTensor(x)
    x_out, meta_out = f(x), f(meta_x)
    compare_all(x_out, meta_out)
    if requires_backward:
        x_out.sum().backward()
        meta_out.sum().backward()
        compare_all(x.grad, meta_x.grad)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
def test_meta_aten():
    for (aten_op, requires_backward), v in registered_meta.items():
        for f, x in v:
            run_and_compare(f, x, requires_backward)


if __name__ == '__main__':
    test_meta_aten()

tests/test_analyzer/test_subclasses/test_flop_tensor.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from .zoo import tm_models, tmm_models

try:
    from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
except:
    pass


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_flop_count_module(m):
    x = torch.rand(2, 3, 224, 224)
    with MetaTensorMode():    # save time for testing
        module = m()
    rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'


odd_cases = [
    (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
        'inplace': True
    }),
    (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
        'kernel_size': 3,
        'stride': 2,
        'padding': 1,
        'dilation': 2
    }),
    (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
                   torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
]


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
    assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'


if __name__ == '__main__':
    test_flop_count_module(tm.resnet18)
    test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})

tests/test_analyzer/test_subclasses/test_meta_mode.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import pytest
import torch
import torch.distributed as dist
import torchvision.models as tm
try:
    from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
    pass
from .zoo import tm_models, tmm_models


def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
    assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
    assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
    assert tensor.stride() == meta_tensor.stride(
    ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'


def run_and_compare(model):
    x = torch.rand(2, 3, 224, 224, requires_grad=True)
    x_out = model(x)
    with MetaTensorMode():
        meta_x = torch.rand(2, 3, 224, 224, requires_grad=True)
        meta_out = model(meta_x)
    compare_all(x_out, meta_out)
    x_out.sum().backward()
    meta_out.sum().backward()
    compare_all(x.grad, meta_x.grad)


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
    run_and_compare(m())


if __name__ == '__main__':
    test_meta_mode_shape(tm.resnet18)

tests/test_analyzer/test_subclasses/zoo.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import timm.models as tmm
import torchvision.models as tm

# input shape: (batch_size, 3, 224, 224)
tm_models = [
    tm.alexnet,
    tm.convnext_base,
    tm.densenet121,
    # tm.efficientnet_v2_s,
    # tm.googlenet,    # output bad case
    # tm.inception_v3,    # bad case
    tm.mobilenet_v2,
    tm.mobilenet_v3_small,
    tm.mnasnet0_5,
    tm.resnet18,
    tm.regnet_x_16gf,
    tm.resnext50_32x4d,
    tm.shufflenet_v2_x0_5,
    tm.squeezenet1_0,
    # tm.swin_s,    # fx bad case
    tm.vgg11,
    tm.vit_b_16,
    tm.wide_resnet50_2,
]

tmm_models = [
    tmm.beit_base_patch16_224,
    tmm.beitv2_base_patch16_224,
    tmm.cait_s24_224,
    tmm.coat_lite_mini,
    tmm.convit_base,
    tmm.deit3_base_patch16_224,
    tmm.dm_nfnet_f0,
    tmm.eca_nfnet_l0,
    tmm.efficientformer_l1,
    tmm.ese_vovnet19b_dw,
    tmm.gmixer_12_224,
    tmm.gmlp_b16_224,
    tmm.hardcorenas_a,
    tmm.hrnet_w18_small,
    tmm.inception_v3,
    tmm.mixer_b16_224,
    tmm.nf_ecaresnet101,
    tmm.nf_regnet_b0,
    # tmm.pit_b_224,    # pretrained only
    tmm.regnetv_040,
    tmm.skresnet18,
    # tmm.swin_base_patch4_window7_224,    # fx bad case
    # tmm.tnt_b_patch16_224,    # bad case
    tmm.vgg11,
    tmm.vit_base_patch16_18x2_224,
    tmm.wide_resnet50_2,
]