[fx] added timm model tracing testing (#1221)

This commit is contained in:
Frank Lee
2022-07-07 14:02:17 +08:00
committed by GitHub
parent 280a81243d
commit b6cb5a47ad
4 changed files with 125 additions and 4 deletions

View File

@@ -3,7 +3,6 @@ from colossalai.fx.proxy import ColoProxy
import pytest
@pytest.mark.skip
def test_coloproxy():
# create a dummy node only for testing purpose
model = torch.nn.Linear(10, 10)

View File

@@ -0,0 +1,82 @@
import torch
import pytest
try:
import timm.models as tm
except:
pass
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
def trace_and_compare(model_cls, tracer, data, meta_args=None):
# trace
model = model_cls()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# convert to eval for inference
model.eval()
gm.eval()
# run forward
with torch.no_grad():
fx_out = gm(data)
non_fx_out = model(data)
# compare output
if isinstance(fx_out, tuple):
# some models produce tuple as output
for v1, v2 in zip(fx_out, non_fx_out):
assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
else:
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST = [
tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, tm.cait.cait_s24_224, tm.convmixer.convmixer_768_32,
tm.efficientnet.efficientnetv2_m, tm.resmlp_12_224, tm.vision_transformer.vit_base_patch16_224
# results not aligned
# tm.deit_base_distilled_patch16_224,
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
trace_and_compare(model_cls, tracer, data)
@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
MODEL_LIST_WITH_CONTROL_FLOW = [
tm.convnext.convnext_base,
tm.vgg.vgg11,
# results not aligned
# tm.dpn.dpn68,
# tm.densenet.densenet121,
# tm.rexnet.rexnet_100,
# tm.swin_transformer.swin_base_patch4_window7_224
]
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
meta_args = {'x': data.to('meta')}
for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
trace_and_compare(model_cls, tracer, data, meta_args)
if __name__ == '__main__':
test_timm_models_with_control_flow()
test_timm_models_without_control_flow()