mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[fx] added torchvision model tracing testing (#1216)
* [fx] added torchvision model tracing testing * remove unused imports
This commit is contained in:
@@ -1,31 +0,0 @@
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
|
||||
def test_maxpool():
|
||||
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.MaxPool1d, shape=(4, 3, 4)),
|
||||
maxpool_2d=dict(layer=torch.nn.MaxPool2d, shape=(4, 3, 4, 4)))
|
||||
|
||||
for name, info in layer_to_test.items():
|
||||
data = torch.rand(*info['shape'])
|
||||
meta_data = data.to('meta')
|
||||
layer = info['layer'](kernel_size=3)
|
||||
out = layer(data)
|
||||
meta_out = layer(meta_data)
|
||||
assert meta_out.is_meta
|
||||
assert out.shape == meta_out.shape
|
||||
|
||||
|
||||
def test_avgpool():
|
||||
layer_to_test = dict(maxpool_1d=dict(layer=torch.nn.AvgPool1d, shape=(4, 3, 4)),
|
||||
maxpool_2d=dict(layer=torch.nn.AvgPool2d, shape=(4, 3, 4, 4)),
|
||||
maxpool_3d=dict(layer=torch.nn.AvgPool3d, shape=(4, 3, 4, 4, 4)))
|
||||
|
||||
for name, info in layer_to_test.items():
|
||||
data = torch.rand(*info['shape'])
|
||||
meta_data = data.to('meta')
|
||||
layer = info['layer'](kernel_size=3)
|
||||
out = layer(data)
|
||||
meta_out = layer(meta_data)
|
||||
assert meta_out.is_meta
|
||||
assert out.shape == meta_out.shape
|
@@ -227,31 +227,88 @@ def test_conv3d():
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
|
||||
def test_maxpool3d():
|
||||
pooler = torch.nn.MaxPool3d(kernel_size=3)
|
||||
def test_pool1d():
|
||||
combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d],
|
||||
[torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]]
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patched_module.torch_nn_maxpool3d,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
for (layer_cls, patch_func) in combinations:
|
||||
pooler = layer_cls(kernel_size=3)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patched_module.torch_nn_maxpool3d,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
data = torch.rand(2, 3, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patched_module.torch_nn_maxpool3d,
|
||||
expect_exception=True,
|
||||
output_shape=None)
|
||||
data = torch.rand(2, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
data = torch.rand(2, 3, 4, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
def test_pool2d():
|
||||
combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d],
|
||||
[torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]]
|
||||
|
||||
for (layer_cls, patch_func) in combinations:
|
||||
pooler = layer_cls(kernel_size=3)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4, 4, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
||||
|
||||
def test_pool3d():
|
||||
combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d],
|
||||
[torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]]
|
||||
|
||||
for (layer_cls, patch_func) in combinations:
|
||||
pooler = layer_cls(kernel_size=3)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 4, 4, 4)
|
||||
materialized_output = pooler(data)
|
||||
_assert_output_shape(data=data,
|
||||
module=pooler,
|
||||
patch_fn=patch_func,
|
||||
expect_exception=False,
|
||||
output_shape=materialized_output.shape)
|
||||
|
||||
# test max pool 3d
|
||||
data = torch.rand(2, 3, 4)
|
||||
_assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None)
|
||||
|
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import pytest
|
||||
try:
|
||||
import torchvision.models as tm
|
||||
except:
|
||||
pass
|
||||
from colossalai.fx import ColoTracer
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
||||
@pytest.mark.skip('skip as torchvision is required')
|
||||
def test_torchvision_models():
|
||||
MODEL_LIST = [
|
||||
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
|
||||
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
|
||||
]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
tracer = ColoTracer()
|
||||
data = torch.rand(2, 3, 224, 224)
|
||||
|
||||
for model_cls in MODEL_LIST:
|
||||
if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
|
||||
# remove the impact of randomicity
|
||||
model = model_cls(stochastic_depth_prob=0)
|
||||
else:
|
||||
model = model_cls()
|
||||
|
||||
graph = tracer.trace(root=model)
|
||||
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
model.eval()
|
||||
gm.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
fx_out = gm(data)
|
||||
non_fx_out = model(data)
|
||||
assert torch.allclose(
|
||||
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_torchvision_models()
|
Reference in New Issue
Block a user