From b6cb5a47ada3fec7151e0f3ff87e07c4355bb46c Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 7 Jul 2022 14:02:17 +0800
Subject: [PATCH] [fx] added timm model tracing testing (#1221)

---
 .../fx/tracer/meta_patch/patched_function.py  | 42 +++++++++-
 .../fx/tracer/meta_patch/patched_module.py    |  4 +-
 tests/test_fx/test_coloproxy.py               |  1 -
 .../test_timm_model/test_timm_model.py        | 82 +++++++++++++++++++
 4 files changed, 125 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_fx/test_tracer/test_timm_model/test_timm_model.py

diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
index d1457d89e..a9ddda8e8 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function.py
@@ -1,3 +1,4 @@
+from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
@@ -99,7 +100,6 @@ def torch_abs(input, *, out=None):
 
 @meta_patched_function.register(torch.nn.functional.relu)
 def torch_nn_func_relu(input, inplace=False):
-    assert not inplace, 'inplace is not supported yet'
     return torch.empty(input.shape, device='meta')
 
 
@@ -178,3 +178,43 @@ def torch_unsqueeze(input, dim):
 @meta_patched_function.register(torch.Tensor.unsqueeze)
 def torch_tensor_unsqueeze(self, dim):
     return torch_unsqueeze(self, dim)
+
+
+@meta_patched_function.register(torch.nn.functional.layer_norm)
+def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.batch_norm)
+def torch_nn_func_batchnorm(input,
+                            running_mean,
+                            running_var,
+                            weight=None,
+                            bias=None,
+                            training=False,
+                            momentum=0.1,
+                            eps=1e-05):
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.var_mean)
+def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
+    assert out is None, 'saving to out is not supported yet'
+    var = torch.empty(1).squeeze(0).to('meta')
+    mean = torch.empty(1).squeeze(0).to('meta')
+    return var, mean
+
+
+@meta_patched_function.register(torch.cat)
+def torch_cat(tensors, dim=None, axis=None, *, out=None):
+    if dim is None and axis is None:
+        dim = 0
+    if dim is None and axis is not None:
+        dim = axis
+    if dim < 0:
+        dim = tensors[0].dim() + dim
+    shapes = [t.shape for t in tensors]
+    shape = list(shapes[0])
+    concatenated_dim = sum(shape[dim] for shape in shapes)
+    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
+    return torch.empty(final_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
index f895e73e9..2e2cedfe2 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -250,6 +250,6 @@ def torch_nn_maxpool3d(self, input):
 
 
 @meta_patched_module.register(torch.nn.ReLU)
+@meta_patched_module.register(torch.nn.ReLU6)
 def torch_nn_func_relu(self, input):
-    assert not self.inplace, 'inplace is not supported yet'
-    return input.clone()
+    return torch.empty(input.shape, device='meta')
diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py
index e498033d8..f3b34a4c0 100644
--- a/tests/test_fx/test_coloproxy.py
+++ b/tests/test_fx/test_coloproxy.py
@@ -3,7 +3,6 @@ from colossalai.fx.proxy import ColoProxy
 import pytest
 
 
-@pytest.mark.skip
 def test_coloproxy():
     # create a dummy node only for testing purpose
     model = torch.nn.Linear(10, 10)
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
new file mode 100644
index 000000000..86a513643
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -0,0 +1,82 @@
+import torch
+import pytest
+try:
+    import timm.models as tm
+except:
+    pass
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+
+def trace_and_compare(model_cls, tracer, data, meta_args=None):
+    # trace
+    model = model_cls()
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # convert to eval for inference
+    model.eval()
+    gm.eval()
+
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(data)
+        non_fx_out = model(data)
+
+    # compare output
+    if isinstance(fx_out, tuple):
+        # some models produce tuple as output
+        for v1, v2 in zip(fx_out, non_fx_out):
+            assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
+    else:
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_without_control_flow():
+    torch.backends.cudnn.deterministic = True
+
+    MODEL_LIST = [
+        tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, tm.cait.cait_s24_224, tm.convmixer.convmixer_768_32,
+        tm.efficientnet.efficientnetv2_m, tm.resmlp_12_224, tm.vision_transformer.vit_base_patch16_224
+
+        # results not aligned
+        # tm.deit_base_distilled_patch16_224,
+    ]
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    for model_cls in MODEL_LIST:
+        trace_and_compare(model_cls, tracer, data)
+
+
+@pytest.mark.skip('skip as timm is required')
+def test_timm_models_with_control_flow():
+    torch.backends.cudnn.deterministic = True
+
+    MODEL_LIST_WITH_CONTROL_FLOW = [
+        tm.convnext.convnext_base,
+        tm.vgg.vgg11,
+
+        # results not aligned
+        # tm.dpn.dpn68,
+        # tm.densenet.densenet121,
+        # tm.rexnet.rexnet_100,
+        # tm.swin_transformer.swin_base_patch4_window7_224
+    ]
+
+    tracer = ColoTracer()
+    data = torch.rand(2, 3, 224, 224)
+
+    meta_args = {'x': data.to('meta')}
+
+    for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
+        trace_and_compare(model_cls, tracer, data, meta_args)
+
+
+if __name__ == '__main__':
+    test_timm_models_with_control_flow()
+    test_timm_models_without_control_flow()
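
Note: the meta-patched functions and modules added above return empty tensors on the 'meta' device, so ColoTracer can record output shapes during tracing without allocating real memory or running kernels. A minimal usage sketch follows, mirroring trace_and_compare and test_timm_models_with_control_flow from the new test file; it assumes timm and colossalai are installed, and the choice of vgg11 and the 2x3x224x224 input are illustrative only:

    import timm.models as tm
    import torch
    from torch.fx import GraphModule

    from colossalai.fx import ColoTracer

    model = tm.vgg.vgg11()
    data = torch.rand(2, 3, 224, 224)

    # Passing meta_args (keyed by the forward argument name, 'x' here) makes the
    # tracer propagate 'meta' tensors through the patched ops, so shape-dependent
    # branches in the forward pass can be resolved at tracing time.
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args={'x': data.to('meta')})
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # The traced module should match the eager model on real inputs.
    model.eval()
    gm.eval()
    with torch.no_grad():
        assert torch.allclose(gm(data), model(data))

Models whose forward pass has no shape-dependent control flow can be traced the same way with meta_args omitted, which is what test_timm_models_without_control_flow exercises.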