From 86ac782d7c8b20289fe42b70ab09dea86024b353 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 14 Mar 2023 14:29:18 +0800 Subject: [PATCH] [test] added timm models to test model zoo (#3129) * [test] added timm models to test model zoo * polish code * polish code * polish code * polish code * polish code --- tests/kit/__init__.py | 0 tests/kit/model_zoo/__init__.py | 4 + tests/kit/model_zoo/registry.py | 63 +++++++ tests/kit/model_zoo/timm/__init__.py | 1 + tests/kit/model_zoo/timm/timm.py | 159 ++++++++++++++++++ .../test_timm_model/test_timm_model.py | 74 ++++---- 6 files changed, 258 insertions(+), 43 deletions(-) create mode 100644 tests/kit/__init__.py create mode 100644 tests/kit/model_zoo/__init__.py create mode 100644 tests/kit/model_zoo/registry.py create mode 100644 tests/kit/model_zoo/timm/__init__.py create mode 100644 tests/kit/model_zoo/timm/timm.py diff --git a/tests/kit/__init__.py b/tests/kit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py new file mode 100644 index 000000000..435daea2c --- /dev/null +++ b/tests/kit/model_zoo/__init__.py @@ -0,0 +1,4 @@ +from . import timm +from .registry import model_zoo + +__all__ = ['model_zoo'] diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py new file mode 100644 index 000000000..4e7dcb30f --- /dev/null +++ b/tests/kit/model_zoo/registry.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +from dataclasses import dataclass +from typing import Callable + +__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo'] + + +@dataclass +class ModelAttribute: + """ + Attributes of a model. + """ + has_control_flow: bool = False + + +class ModelZooRegistry(dict): + """ + A registry to map model names to model and data generation functions. + """ + + def register(self, + name: str, + model_fn: Callable, + data_gen_fn: Callable, + output_transform_fn: Callable, + model_attribute: ModelAttribute = None): + """ + Register a model and data generation function. + + Examples: + >>> # Register + >>> model_zoo = ModelZooRegistry() + >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) + >>> # Run the model + >>> data = resnresnet18_data_gen() # do not input any argument + >>> model = resnet18() # do not input any argument + >>> out = model(**data) + + Args: + name (str): Name of the model. + model_fn (callable): A function that returns a model. **It must not contain any arguments.** + output_transform_fn (callable): A function that transforms the output of the model into Dict. + data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_attribute (ModelAttribute): Attributes of the model. Defaults to None. + """ + self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) + + def get_sub_registry(self, keyword: str): + """ + Get a sub registry with models that contain the keyword. + + Args: + keyword (str): Keyword to filter models. + """ + new_dict = dict() + + for k, v in self.items(): + if keyword in k: + new_dict[k] = v + return new_dict + + +model_zoo = ModelZooRegistry() diff --git a/tests/kit/model_zoo/timm/__init__.py b/tests/kit/model_zoo/timm/__init__.py new file mode 100644 index 000000000..c9c853194 --- /dev/null +++ b/tests/kit/model_zoo/timm/__init__.py @@ -0,0 +1 @@ +from .timm import * diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py new file mode 100644 index 000000000..b29ac12a6 --- /dev/null +++ b/tests/kit/model_zoo/timm/timm.py @@ -0,0 +1,159 @@ +import timm.models as tm +import torch + +from ..registry import ModelAttribute, model_zoo + +## ============== +# Register models without control flow +## ============== +data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) +output_transform_fn = lambda x: dict(output=x) + +model_zoo.register(name='timm_resnet', + model_fn=tm.resnest.resnest50d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_beit', + model_fn=tm.beit.beit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_cait', + model_fn=tm.cait.cait_s24_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_convmixer', + model_fn=tm.convmixer.convmixer_768_32, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_efficientnetv2', + model_fn=tm.efficientnet.efficientnetv2_m, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_resmlp', + model_fn=tm.resmlp_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_vision_transformer', + model_fn=tm.vision_transformer.vit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_deit', + model_fn=tm.deit_base_distilled_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_beitv2', + model_fn=tm.beitv2_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_coat', + model_fn=tm.coat.coat_lite_mini, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='timm_deit3', + model_fn=tm.deit3_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +model_zoo.register(name='timm_eca_nfnet', + model_fn=tm.eca_nfnet_l0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_efficientformer', + model_fn=tm.efficientformer_l1, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_ese_vovnet19b_dw', + model_fn=tm.ese_vovnet19b_dw, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_gmixer_12_224', + model_fn=tm.gmixer_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_gmlp_b16_224', + model_fn=tm.gmlp_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_hardcorenas_a', + model_fn=tm.hardcorenas_a, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_hrnet_w18_small', + model_fn=tm.hrnet_w18_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_inception_v3', + model_fn=tm.inception_v3, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_mixer_b16_224', + model_fn=tm.mixer_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_nf_ecaresnet101', + model_fn=tm.nf_ecaresnet101, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_nf_regnet_b0', + model_fn=tm.nf_regnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_regnetv_040', + model_fn=tm.regnetv_040, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_skresnet18', + model_fn=tm.skresnet18, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_tnt_b_patch16_224', + model_fn=tm.tnt_b_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_wide_resnet50_2', + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_convit', + model_fn=tm.convit_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) +model_zoo.register(name='timm_dm_nfnet', + model_fn=tm.dm_nfnet_f0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn) + +# ============== +# Register models with control flow +# ============== +model_zoo.register(name='timm_convnext', + model_fn=tm.convnext.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_vgg', + model_fn=tm.vgg.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_dpn', + model_fn=tm.dpn.dpn68, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_densenet', + model_fn=tm.densenet.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_rexnet', + model_fn=tm.rexnet.rexnet_100, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='timm_swin_transformer', + model_fn=tm.swin_transformer.swin_base_patch4_window7_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 28ec3d825..31baa3e89 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -3,9 +3,10 @@ import timm.models as tm import torch from colossalai.fx import symbolic_trace +from tests.kit.model_zoo import model_zoo -def trace_and_compare(model_cls, data, meta_args=None): +def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # trace model = model_cls() @@ -14,60 +15,47 @@ def trace_and_compare(model_cls, data, meta_args=None): # without this statement, the torch.nn.functional.batch_norm will always be in training mode model.eval() + # TODO: support the following models + # 1. ConViT + # 2. NormFreeNet + # as they are not supported, let's skip them + if model.__class__.__name__ in ['ConViT', 'NormFreeNet']: + return + gm = symbolic_trace(model, meta_args=meta_args) # run forward with torch.no_grad(): - fx_out = gm(data) - non_fx_out = model(data) + fx_out = gm(**data) + non_fx_out = model(**data) # compare output - if isinstance(fx_out, tuple): - # some models produce tuple as output - for v1, v2 in zip(fx_out, non_fx_out): - assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}' - else: - assert torch.allclose( - fx_out, non_fx_out, - atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + transformed_fx_out = output_transform_fn(fx_out) + transformed_non_fx_out = output_transform_fn(non_fx_out) + + assert len(transformed_fx_out) == len(transformed_non_fx_out) + + for key in transformed_fx_out.keys(): + fx_output_val = transformed_fx_out[key] + non_fx_output_val = transformed_non_fx_out[key] + assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ + f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' -def test_timm_models_without_control_flow(): +def test_timm_models(): torch.backends.cudnn.deterministic = True - MODEL_LIST = [ - tm.resnest.resnest50d, - tm.beit.beit_base_patch16_224, - tm.cait.cait_s24_224, - tm.convmixer.convmixer_768_32, - tm.efficientnet.efficientnetv2_m, - tm.resmlp_12_224, - tm.vision_transformer.vit_base_patch16_224, - tm.deit_base_distilled_patch16_224, - ] + sub_model_zoo = model_zoo.get_sub_registry('timm') - data = torch.rand(2, 3, 224, 224) + for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + data = data_gen_fn() + if attribute is not None and attribute.has_control_flow: + meta_args = {k: v.to('meta') for k, v in data.items()} + else: + meta_args = None - for model_cls in MODEL_LIST: - trace_and_compare(model_cls, data) - - -def test_timm_models_with_control_flow(): - torch.backends.cudnn.deterministic = True - - MODEL_LIST_WITH_CONTROL_FLOW = [ - tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, - tm.swin_transformer.swin_base_patch4_window7_224 - ] - - data = torch.rand(2, 3, 224, 224) - - meta_args = {'x': data.to('meta')} - - for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: - trace_and_compare(model_cls, data, meta_args) + trace_and_compare(model_fn, data, output_transform_fn, meta_args) if __name__ == '__main__': - test_timm_models_with_control_flow() - test_timm_models_without_control_flow() + test_timm_models()