mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 21:22:04 +00:00
[fx] fixed timm tracing result misalignment (#1225)
This commit is contained in:
parent
b6cb5a47ad
commit
37fcf96b7f
@ -11,14 +11,16 @@ from torch.fx import GraphModule
|
|||||||
def trace_and_compare(model_cls, tracer, data, meta_args=None):
|
def trace_and_compare(model_cls, tracer, data, meta_args=None):
|
||||||
# trace
|
# trace
|
||||||
model = model_cls()
|
model = model_cls()
|
||||||
|
|
||||||
|
# convert to eval for inference
|
||||||
|
# it is important to set it to eval mode before tracing
|
||||||
|
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
|
||||||
|
model.eval()
|
||||||
|
|
||||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
# convert to eval for inference
|
|
||||||
model.eval()
|
|
||||||
gm.eval()
|
|
||||||
|
|
||||||
# run forward
|
# run forward
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
fx_out = gm(data)
|
fx_out = gm(data)
|
||||||
@ -39,11 +41,14 @@ def test_timm_models_without_control_flow():
|
|||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, tm.cait.cait_s24_224, tm.convmixer.convmixer_768_32,
|
tm.resnest.resnest50d,
|
||||||
tm.efficientnet.efficientnetv2_m, tm.resmlp_12_224, tm.vision_transformer.vit_base_patch16_224
|
tm.beit.beit_base_patch16_224,
|
||||||
|
tm.cait.cait_s24_224,
|
||||||
# results not aligned
|
tm.convmixer.convmixer_768_32,
|
||||||
# tm.deit_base_distilled_patch16_224,
|
tm.efficientnet.efficientnetv2_m,
|
||||||
|
tm.resmlp_12_224,
|
||||||
|
tm.vision_transformer.vit_base_patch16_224,
|
||||||
|
tm.deit_base_distilled_patch16_224,
|
||||||
]
|
]
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
@ -60,11 +65,11 @@ def test_timm_models_with_control_flow():
|
|||||||
MODEL_LIST_WITH_CONTROL_FLOW = [
|
MODEL_LIST_WITH_CONTROL_FLOW = [
|
||||||
tm.convnext.convnext_base,
|
tm.convnext.convnext_base,
|
||||||
tm.vgg.vgg11,
|
tm.vgg.vgg11,
|
||||||
|
tm.dpn.dpn68,
|
||||||
|
tm.densenet.densenet121,
|
||||||
|
tm.rexnet.rexnet_100,
|
||||||
|
|
||||||
# results not aligned
|
# not traceable
|
||||||
# tm.dpn.dpn68,
|
|
||||||
# tm.densenet.densenet121,
|
|
||||||
# tm.rexnet.rexnet_100,
|
|
||||||
# tm.swin_transformer.swin_base_patch4_window7_224
|
# tm.swin_transformer.swin_base_patch4_window7_224
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user