diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 86a513643..b69900d0b 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -11,14 +11,16 @@ from torch.fx import GraphModule def trace_and_compare(model_cls, tracer, data, meta_args=None): # trace model = model_cls() + + # convert to eval for inference + # it is important to set it to eval mode before tracing + # without this statement, the torch.nn.functional.batch_norm will always be in training mode + model.eval() + graph = tracer.trace(root=model, meta_args=meta_args) gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - # convert to eval for inference - model.eval() - gm.eval() - # run forward with torch.no_grad(): fx_out = gm(data) @@ -39,11 +41,14 @@ def test_timm_models_without_control_flow(): torch.backends.cudnn.deterministic = True MODEL_LIST = [ - tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, tm.cait.cait_s24_224, tm.convmixer.convmixer_768_32, - tm.efficientnet.efficientnetv2_m, tm.resmlp_12_224, tm.vision_transformer.vit_base_patch16_224 - - # results not aligned - # tm.deit_base_distilled_patch16_224, + tm.resnest.resnest50d, + tm.beit.beit_base_patch16_224, + tm.cait.cait_s24_224, + tm.convmixer.convmixer_768_32, + tm.efficientnet.efficientnetv2_m, + tm.resmlp_12_224, + tm.vision_transformer.vit_base_patch16_224, + tm.deit_base_distilled_patch16_224, ] tracer = ColoTracer() @@ -60,11 +65,11 @@ def test_timm_models_with_control_flow(): MODEL_LIST_WITH_CONTROL_FLOW = [ tm.convnext.convnext_base, tm.vgg.vgg11, + tm.dpn.dpn68, + tm.densenet.densenet121, + tm.rexnet.rexnet_100, - # results not aligned - # tm.dpn.dpn68, - # tm.densenet.densenet121, - # tm.rexnet.rexnet_100, + # not traceable # tm.swin_transformer.swin_base_patch4_window7_224 ]