[fx] added torchvision model tracing testing (#1216)

* [fx] added torchvision model tracing testing

* remove unused imports
This commit is contained in:
Frank Lee
2022-07-06 21:37:56 +08:00
committed by GitHub
parent 52736205d9
commit 11973d892d
6 changed files with 346 additions and 71 deletions

View File

@@ -1,31 +0,0 @@
import torch
import torch.nn
def test_maxpool():
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)),
maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4)))
for name, info in layer_to_test.items():
data = torch.rand(*info['shape'])
meta_data = data.to('meta')
layer = info['layer'](kernel_size=3)
out = layer(data)
meta_out = layer(meta_data)
assert meta_out.is_meta
assert out.shape == meta_out.shape
def test_avgpool():
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)),
maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)),
maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4)))
for name, info in layer_to_test.items():
data = torch.rand(*info['shape'])
meta_data = data.to('meta')
layer = info['layer'](kernel_size=3)
out = layer(data)
meta_out = layer(meta_data)
assert meta_out.is_meta
assert out.shape == meta_out.shape

View File

@@ -227,31 +227,88 @@ def test_conv3d():
output_shape=materialized_output.shape)
def test_maxpool3d():
pooler = torch.nn.MaxPool3d(kernel_size=3)
def test_pool1d():
combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
[torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]
# test max pool 3d
data = torch.rand(2, 3, 4, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
for (layer_cls, patch_func) in combinations:
pooler = layer_cls(kernel_size=3)
# test max pool 3d
data = torch.rand(2, 3, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=False,
output_shape=materialized_output.shape)
data = torch.rand(2, 3, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patched_module.torch_nn_maxpool3d,
expect_exception=True,
output_shape=None)
data = torch.rand(2, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
data = torch.rand(2, 3, 4, 4)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
def test_pool2d():
combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d],
[torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]]
for (layer_cls, patch_func) in combinations:
pooler = layer_cls(kernel_size=3)
# test max pool 3d
data = torch.rand(2, 3, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4, 4, 4)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
def test_pool3d():
combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d],
[torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]]
for (layer_cls, patch_func) in combinations:
pooler = layer_cls(kernel_size=3)
# test max pool 3d
data = torch.rand(2, 3, 4, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 4, 4, 4)
materialized_output = pooler(data)
_assert_output_shape(data=data,
module=pooler,
patch_fn=patch_func,
expect_exception=False,
output_shape=materialized_output.shape)
# test max pool 3d
data = torch.rand(2, 3, 4)
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)

View File

@@ -0,0 +1,46 @@
import torch
import pytest
try:
import torchvision.models as tm
except:
pass
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
]
torch.backends.cudnn.deterministic = True
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
# remove the impact of randomicity
model = model_cls(stochastic_depth_prob=0)
else:
model = model_cls()
graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
model.eval()
gm.eval()
with torch.no_grad():
fx_out = gm(data)
non_fx_out = model(data)
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
if __name__ == '__main__':
test_torchvision_models()