Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,4 +1,4 @@
 from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
 from .registry import model_zoo

-__all__ = ['model_zoo']
+__all__ = ["model_zoo"]
@@ -4,7 +4,7 @@ import diffusers
 import torch
 import transformers

-from ..registry import ModelAttribute, model_zoo
+from ..registry import model_zoo

 BATCH_SIZE = 2
 SEQ_LENGTH = 5
@@ -26,10 +26,9 @@ def data_clip_model():
     attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
     position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
     pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
-    return dict(input_ids=input_ids,
-                pixel_values=pixel_values,
-                attention_mask=attention_mask,
-                position_ids=position_ids)
+    return dict(
+        input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids
+    )


 def data_clip_text():
@@ -43,32 +42,41 @@ def data_clip_vision():
     return dict(pixel_values=pixel_values)


-model_zoo.register(name='diffusers_auto_encoder_kl',
-                   model_fn=diffusers.AutoencoderKL,
-                   data_gen_fn=data_vae_fn,
-                   output_transform_fn=identity_output)
+model_zoo.register(
+    name="diffusers_auto_encoder_kl",
+    model_fn=diffusers.AutoencoderKL,
+    data_gen_fn=data_vae_fn,
+    output_transform_fn=identity_output,
+)

-model_zoo.register(name='diffusers_vq_model',
-                   model_fn=diffusers.VQModel,
-                   data_gen_fn=data_vae_fn,
-                   output_transform_fn=identity_output)
+model_zoo.register(
+    name="diffusers_vq_model", model_fn=diffusers.VQModel, data_gen_fn=data_vae_fn, output_transform_fn=identity_output
+)

-model_zoo.register(name='diffusers_clip_model',
-                   model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()),
-                   data_gen_fn=data_clip_model,
-                   output_transform_fn=identity_output)
+model_zoo.register(
+    name="diffusers_clip_model",
+    model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()),
+    data_gen_fn=data_clip_model,
+    output_transform_fn=identity_output,
+)

-model_zoo.register(name='diffusers_clip_text_model',
-                   model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()),
-                   data_gen_fn=data_clip_text,
-                   output_transform_fn=identity_output)
+model_zoo.register(
+    name="diffusers_clip_text_model",
+    model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()),
+    data_gen_fn=data_clip_text,
+    output_transform_fn=identity_output,
+)

-model_zoo.register(name='diffusers_clip_vision_model',
-                   model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()),
-                   data_gen_fn=data_clip_vision,
-                   output_transform_fn=clip_vision_model_output)
+model_zoo.register(
+    name="diffusers_clip_vision_model",
+    model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()),
+    data_gen_fn=data_clip_vision,
+    output_transform_fn=clip_vision_model_output,
+)

-model_zoo.register(name='diffusers_unet2d_model',
-                   model_fn=diffusers.UNet2DModel,
-                   data_gen_fn=data_unet_fn,
-                   output_transform_fn=identity_output)
+model_zoo.register(
+    name="diffusers_unet2d_model",
+    model_fn=diffusers.UNet2DModel,
+    data_gen_fn=data_unet_fn,
+    output_transform_fn=identity_output,
+)
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from typing import Callable

-__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
+__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"]


 @dataclass
@@ -14,6 +14,7 @@ class ModelAttribute:
         has_control_flow (bool): Whether the model contains branching in its forward method.
         has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models.
     """
+
     has_control_flow: bool = False
     has_stochastic_depth_prob: bool = False

@@ -23,13 +24,15 @@ class ModelZooRegistry(dict):
     A registry to map model names to model and data generation functions.
     """

-    def register(self,
-                 name: str,
-                 model_fn: Callable,
-                 data_gen_fn: Callable,
-                 output_transform_fn: Callable,
-                 loss_fn: Callable = None,
-                 model_attribute: ModelAttribute = None):
+    def register(
+        self,
+        name: str,
+        model_fn: Callable,
+        data_gen_fn: Callable,
+        output_transform_fn: Callable,
+        loss_fn: Callable = None,
+        model_attribute: ModelAttribute = None,
+    ):
         """
         Register a model and data generation function.

@@ -71,7 +74,7 @@ class ModelZooRegistry(dict):
             if keyword in k:
                 new_dict[k] = v

-        assert len(new_dict) > 0, f'No model found with keyword {keyword}'
+        assert len(new_dict) > 0, f"No model found with keyword {keyword}"
         return new_dict


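Editor's note: the registry touched above is a plain `dict` keyed by model name, with `register` storing callables and `get_sub_registry` filtering names by keyword. A minimal self-contained sketch of the pattern, for readers who do not have the repository at hand (`MiniRegistry` and `demo_linear` are hypothetical stand-ins; the real class lives in the file being diffed):

```python
import torch


# Minimal re-creation of the registry pattern for illustration only.
class MiniRegistry(dict):
    def register(self, name, model_fn, data_gen_fn, output_transform_fn, loss_fn=None, model_attribute=None):
        # Store callables, not instances: models are only built when a test asks for them.
        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

    def get_sub_registry(self, keyword):
        new_dict = {k: v for k, v in self.items() if keyword in k}
        assert len(new_dict) > 0, f"No model found with keyword {keyword}"
        return new_dict


zoo = MiniRegistry()
zoo.register(
    name="demo_linear",
    model_fn=lambda: torch.nn.Linear(8, 2),
    data_gen_fn=lambda: dict(input=torch.rand(4, 8)),
    output_transform_fn=lambda x: dict(output=x),
)

model_fn, data_gen_fn, transform_fn, _, _ = zoo["demo_linear"]
outputs = transform_fn(model_fn()(**data_gen_fn()))  # {"output": tensor of shape (4, 2)}
```

Keeping only callables makes registration import-cheap: nothing heavyweight is instantiated until a test calls `model_fn()`.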
@@ -9,151 +9,183 @@ from ..registry import ModelAttribute, model_zoo
 data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224))
 output_transform_fn = lambda x: dict(output=x)

-model_zoo.register(name='timm_resnet',
-                   model_fn=tm.resnest.resnest50d,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_beit',
-                   model_fn=tm.beit.beit_base_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_cait',
-                   model_fn=tm.cait.cait_s24_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_convmixer',
-                   model_fn=tm.convmixer.convmixer_768_32,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_efficientnetv2',
-                   model_fn=tm.efficientnet.efficientnetv2_m,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_resmlp',
-                   model_fn=tm.resmlp_12_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_vision_transformer',
-                   model_fn=tm.vision_transformer.vit_base_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_deit',
-                   model_fn=tm.deit_base_distilled_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_beitv2',
-                   model_fn=tm.beitv2_base_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_coat',
-                   model_fn=tm.coat.coat_lite_mini,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="timm_resnet", model_fn=tm.resnest.resnest50d, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_beit",
+    model_fn=tm.beit.beit_base_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_cait", model_fn=tm.cait.cait_s24_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_convmixer",
+    model_fn=tm.convmixer.convmixer_768_32,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_efficientnetv2",
+    model_fn=tm.efficientnet.efficientnetv2_m,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_resmlp", model_fn=tm.resmlp_12_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_vision_transformer",
+    model_fn=tm.vision_transformer.vit_base_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_deit",
+    model_fn=tm.deit_base_distilled_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_beitv2",
+    model_fn=tm.beitv2_base_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_coat", model_fn=tm.coat.coat_lite_mini, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)

-model_zoo.register(name='timm_deit3',
-                   model_fn=tm.deit3_base_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="timm_deit3",
+    model_fn=tm.deit3_base_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='timm_eca_nfnet',
-                   model_fn=tm.eca_nfnet_l0,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_efficientformer',
-                   model_fn=tm.efficientformer_l1,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_ese_vovnet19b_dw',
-                   model_fn=tm.ese_vovnet19b_dw,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_gmixer_12_224',
-                   model_fn=tm.gmixer_12_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_gmlp_b16_224',
-                   model_fn=tm.gmlp_b16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_hardcorenas_a',
-                   model_fn=tm.hardcorenas_a,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_hrnet_w18_small',
-                   model_fn=tm.hrnet_w18_small,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_inception_v3',
-                   model_fn=tm.inception_v3,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_mixer_b16_224',
-                   model_fn=tm.mixer_b16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_nf_ecaresnet101',
-                   model_fn=tm.nf_ecaresnet101,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_nf_regnet_b0',
-                   model_fn=tm.nf_regnet_b0,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_regnetv_040',
-                   model_fn=tm.regnetv_040,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_skresnet18',
-                   model_fn=tm.skresnet18,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_tnt_b_patch16_224',
-                   model_fn=tm.tnt_b_patch16_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_wide_resnet50_2',
-                   model_fn=tm.wide_resnet50_2,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_convit',
-                   model_fn=tm.convit_base,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='timm_dm_nfnet',
-                   model_fn=tm.dm_nfnet_f0,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="timm_eca_nfnet", model_fn=tm.eca_nfnet_l0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_efficientformer",
+    model_fn=tm.efficientformer_l1,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_ese_vovnet19b_dw",
+    model_fn=tm.ese_vovnet19b_dw,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_gmixer_12_224",
+    model_fn=tm.gmixer_12_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_gmlp_b16_224", model_fn=tm.gmlp_b16_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_hardcorenas_a",
+    model_fn=tm.hardcorenas_a,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_hrnet_w18_small",
+    model_fn=tm.hrnet_w18_small,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_inception_v3", model_fn=tm.inception_v3, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_mixer_b16_224",
+    model_fn=tm.mixer_b16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_nf_ecaresnet101",
+    model_fn=tm.nf_ecaresnet101,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_nf_regnet_b0", model_fn=tm.nf_regnet_b0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_regnetv_040", model_fn=tm.regnetv_040, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_skresnet18", model_fn=tm.skresnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_tnt_b_patch16_224",
+    model_fn=tm.tnt_b_patch16_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_wide_resnet50_2",
+    model_fn=tm.wide_resnet50_2,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="timm_convit", model_fn=tm.convit_base, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="timm_dm_nfnet", model_fn=tm.dm_nfnet_f0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)

 # ==============
 # Register models with control flow
 # ==============
-model_zoo.register(name='timm_convnext',
-                   model_fn=tm.convnext.convnext_base,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='timm_vgg',
-                   model_fn=tm.vgg.vgg11,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='timm_dpn',
-                   model_fn=tm.dpn.dpn68,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='timm_densenet',
-                   model_fn=tm.densenet.densenet121,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='timm_rexnet',
-                   model_fn=tm.rexnet.rexnet_100,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='timm_swin_transformer',
-                   model_fn=tm.swin_transformer.swin_base_patch4_window7_224,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="timm_convnext",
+    model_fn=tm.convnext.convnext_base,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="timm_vgg",
+    model_fn=tm.vgg.vgg11,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="timm_dpn",
+    model_fn=tm.dpn.dpn68,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="timm_densenet",
+    model_fn=tm.densenet.densenet121,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="timm_rexnet",
+    model_fn=tm.rexnet.rexnet_100,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="timm_swin_transformer",
+    model_fn=tm.swin_transformer.swin_base_patch4_window7_224,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
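Editor's note: `ModelAttribute(has_control_flow=True)` flags models whose `forward` branches at runtime, which matters for graph-capture tests. A hedged toy sketch (the `Branchy` module is hypothetical, not part of this commit) of why such branches are special for symbolic tracers like `torch.fx`:

```python
import torch


class Branchy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x, use_proj: bool = True):
        if use_proj:  # flag-dependent branch -> control flow in the graph
            x = self.proj(x)
        return x


# Plain symbolic tracing would typically fail on the `if`, since a Proxy
# cannot be used as a boolean; pinning the flag as a concrete arg lets the
# trace specialize one side of the branch.
traced = torch.fx.symbolic_trace(Branchy(), concrete_args={"use_proj": True})
print(traced.graph)
```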
@@ -23,24 +23,31 @@ def conformer_data_gen_fn():

 transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])

-model_zoo.register(name='torchaudio_conformer',
-                   model_fn=lambda: tm.Conformer(
-                       input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31),
-                   data_gen_fn=conformer_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_conformer",
+    model_fn=lambda: tm.Conformer(
+        input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31
+    ),
+    data_gen_fn=conformer_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+)

 single_output_transform_fn = lambda output: dict(output=output)

-model_zoo.register(name='torchaudio_convtasnet',
-                   model_fn=tm.ConvTasNet,
-                   data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
-                   output_transform_fn=single_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_convtasnet",
+    model_fn=tm.ConvTasNet,
+    data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
+    output_transform_fn=single_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

-model_zoo.register(name='torchaudio_deepspeech',
-                   model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_deepspeech",
+    model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
+    output_transform_fn=single_output_transform_fn,
+)


 def emformer_data_gen_fn():
@@ -50,21 +57,26 @@ def emformer_data_gen_fn():


 model_zoo.register(
-    name='torchaudio_emformer',
+    name="torchaudio_emformer",
     model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
     data_gen_fn=emformer_data_gen_fn,
     output_transform_fn=transformer_output_transform_fn,
-    model_attribute=ModelAttribute(has_control_flow=True))
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

-model_zoo.register(name='torchaudio_wav2letter_waveform',
-                   model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_wav2letter_waveform",
+    model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+    output_transform_fn=single_output_transform_fn,
+)

-model_zoo.register(name='torchaudio_wav2letter_mfcc',
-                   model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
-                   data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
-                   output_transform_fn=single_output_transform_fn)
+model_zoo.register(
+    name="torchaudio_wav2letter_mfcc",
+    model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40),
+    data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+    output_transform_fn=single_output_transform_fn,
+)


 def wavernn_data_gen_fn():
@@ -73,20 +85,24 @@ def wavernn_data_gen_fn():
     return dict(waveform=waveform, specgram=specgram)


-model_zoo.register(name='torchaudio_wavernn',
-                   model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5],
-                                               n_classes=N_CLASSES,
-                                               hop_length=HOP_LENGTH,
-                                               kernel_size=KERNEL_SIZE,
-                                               n_freq=N_FREQ,
-                                               n_res_block=2,
-                                               n_rnn=64,
-                                               n_fc=64,
-                                               n_hidden=16,
-                                               n_output=16),
-                   data_gen_fn=wavernn_data_gen_fn,
-                   output_transform_fn=single_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_wavernn",
+    model_fn=lambda: tm.WaveRNN(
+        upsample_scales=[2, 2, 5],
+        n_classes=N_CLASSES,
+        hop_length=HOP_LENGTH,
+        kernel_size=KERNEL_SIZE,
+        n_freq=N_FREQ,
+        n_res_block=2,
+        n_rnn=64,
+        n_fc=64,
+        n_hidden=16,
+        n_output=16,
+    ),
+    data_gen_fn=wavernn_data_gen_fn,
+    output_transform_fn=single_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)


 def tacotron_data_gen_fn():
@@ -97,17 +113,18 @@ def tacotron_data_gen_fn():
     token_lengths = max_text_length * torch.ones((n_batch,))
     mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
     mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
-    return dict(tokens=tokens,
-                token_lengths=token_lengths,
-                mel_specgram=mel_specgram,
-                mel_specgram_lengths=mel_specgram_lengths)
+    return dict(
+        tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths
+    )


-model_zoo.register(name='torchaudio_tacotron',
-                   model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
-                   data_gen_fn=tacotron_data_gen_fn,
-                   output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_tacotron",
+    model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
+    data_gen_fn=tacotron_data_gen_fn,
+    output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
+    model_attribute=ModelAttribute(has_control_flow=True),
+)


 def wav2vec_data_gen_fn():
@@ -117,14 +134,18 @@ def wav2vec_data_gen_fn():
     return dict(waveforms=waveforms, lengths=lengths)


-model_zoo.register(name='torchaudio_wav2vec2_base',
-                   model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
-                   data_gen_fn=wav2vec_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_wav2vec2_base",
+    model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
+    data_gen_fn=wav2vec_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

-model_zoo.register(name='torchaudio_hubert_base',
-                   model_fn=tm.hubert_base,
-                   data_gen_fn=wav2vec_data_gen_fn,
-                   output_transform_fn=transformer_output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="torchaudio_hubert_base",
+    model_fn=tm.hubert_base,
+    data_gen_fn=wav2vec_data_gen_fn,
+    output_transform_fn=transformer_output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
@@ -1,4 +1,3 @@
-from collections import namedtuple
 from functools import partial

 import torch
@@ -7,7 +6,7 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

-from ..registry import ModelAttribute, model_zoo
+from ..registry import model_zoo

 BATCH = 2
 SHAPE = 10
@@ -20,9 +19,9 @@ def gen_kt():

 # KeyedJaggedTensor
 def gen_kjt():
-    KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
-                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
+    KJT = KeyedJaggedTensor.from_offsets_sync(
+        keys=["f1", "f2"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), offsets=torch.tensor([0, 2, 4, 6, 8])
+    )
     return KJT


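Editor's note: for readers unfamiliar with torchrec, the offsets in `gen_kjt` slice the flat `values` into four length-2 bags, laid out key-major over a batch of 2 (`offsets[i + 1] - offsets[i]` is the length of slot `i`). A short sketch of how they decode, assuming torchrec is installed (the printed results are what one would expect):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Slots in order: (f1, sample0), (f1, sample1), (f2, sample0), (f2, sample1).
kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
    offsets=torch.tensor([0, 2, 4, 6, 8]),
)
jt_dict = kjt.to_dict()
print(jt_dict["f1"].values())  # tensor([1, 2, 3, 4]) -> bags [1, 2] and [3, 4]
print(jt_dict["f2"].values())  # tensor([5, 6, 7, 8]) -> bags [5, 6] and [7, 8]
```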
@@ -68,7 +67,7 @@ def get_ebc():
     # EmbeddingBagCollection
     eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
     eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-    return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu'))
+    return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu"))


 def sparse_arch_model_fn():
@@ -91,52 +90,69 @@ def dlrm_sparsearch_model_fn():
     return dlrm.SparseArch(ebc)


-model_zoo.register(name='deepfm_densearch',
-                   model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="deepfm_densearch",
+    model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='deepfm_interactionarch',
-                   model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
-                   data_gen_fn=interaction_arch_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="deepfm_interactionarch",
+    model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
+    data_gen_fn=interaction_arch_data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='deepfm_overarch',
-                   model_fn=partial(deepfm.OverArch, SHAPE),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="deepfm_overarch",
+    model_fn=partial(deepfm.OverArch, SHAPE),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='deepfm_simpledeepfmnn',
-                   model_fn=simple_deep_fmnn_model_fn,
-                   data_gen_fn=simple_dfm_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="deepfm_simpledeepfmnn",
+    model_fn=simple_deep_fmnn_model_fn,
+    data_gen_fn=simple_dfm_data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='deepfm_sparsearch',
-                   model_fn=sparse_arch_model_fn,
-                   data_gen_fn=sparse_arch_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="deepfm_sparsearch",
+    model_fn=sparse_arch_model_fn,
+    data_gen_fn=sparse_arch_data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='dlrm',
-                   model_fn=dlrm_model_fn,
-                   data_gen_fn=simple_dfm_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="dlrm", model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn
+)

-model_zoo.register(name='dlrm_densearch',
-                   model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="dlrm_densearch",
+    model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='dlrm_interactionarch',
-                   model_fn=partial(dlrm.InteractionArch, 2),
-                   data_gen_fn=interaction_arch_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="dlrm_interactionarch",
+    model_fn=partial(dlrm.InteractionArch, 2),
+    data_gen_fn=interaction_arch_data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='dlrm_overarch',
-                   model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="dlrm_overarch",
+    model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='dlrm_sparsearch',
-                   model_fn=dlrm_sparsearch_model_fn,
-                   data_gen_fn=sparse_arch_data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="dlrm_sparsearch",
+    model_fn=dlrm_sparsearch_model_fn,
+    data_gen_fn=sparse_arch_data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
@@ -1,5 +1,3 @@
-from collections import namedtuple
-
 import torch
 import torchvision
 import torchvision.models as tm
@@ -29,103 +27,133 @@ def swin_s():
         depths=[2, 2, 6, 2],
         num_heads=[3, 6, 12, 24],
         window_size=[7, 7],
-        stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic
+        stochastic_depth_prob=0,  # it is originally 0.2, but we set it to 0 to make it deterministic
         weights=weights,
         progress=progress,
     )


 # special output transform fn
-google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs
-                                                                             ) else dict(output=x)
-swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
-                                               for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
-inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs
-                                                                              ) else dict(output=x)
+google_net_output_transform_fn = (
+    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
+)
+swin_s_output_output_transform_fn = (
+    lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
+)
+inception_v3_output_transform_fn = (
+    lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x)
+)

-model_zoo.register(name='torchvision_alexnet',
-                   model_fn=tm.alexnet,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_densenet121',
-                   model_fn=tm.densenet121,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_efficientnet_b0',
-                   model_fn=tm.efficientnet_b0,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
-model_zoo.register(name='torchvision_googlenet',
-                   model_fn=tm.googlenet,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=google_net_output_transform_fn)
-model_zoo.register(name='torchvision_inception_v3',
-                   model_fn=tm.inception_v3,
-                   data_gen_fn=inception_v3_data_gen_fn,
-                   output_transform_fn=inception_v3_output_transform_fn)
-model_zoo.register(name='torchvision_mobilenet_v2',
-                   model_fn=tm.mobilenet_v2,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_mobilenet_v3_small',
-                   model_fn=tm.mobilenet_v3_small,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_mnasnet0_5',
-                   model_fn=tm.mnasnet0_5,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_resnet18',
-                   model_fn=tm.resnet18,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_regnet_x_16gf',
-                   model_fn=tm.regnet_x_16gf,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_resnext50_32x4d',
-                   model_fn=tm.resnext50_32x4d,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_shufflenet_v2_x0_5',
-                   model_fn=tm.shufflenet_v2_x0_5,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_squeezenet1_0',
-                   model_fn=tm.squeezenet1_0,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="torchvision_alexnet", model_fn=tm.alexnet, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="torchvision_densenet121",
+    model_fn=tm.densenet121,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_efficientnet_b0",
+    model_fn=tm.efficientnet_b0,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
+)
+model_zoo.register(
+    name="torchvision_googlenet",
+    model_fn=tm.googlenet,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=google_net_output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_inception_v3",
+    model_fn=tm.inception_v3,
+    data_gen_fn=inception_v3_data_gen_fn,
+    output_transform_fn=inception_v3_output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_mobilenet_v2",
+    model_fn=tm.mobilenet_v2,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_mobilenet_v3_small",
+    model_fn=tm.mobilenet_v3_small,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_mnasnet0_5",
+    model_fn=tm.mnasnet0_5,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_resnet18", model_fn=tm.resnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="torchvision_regnet_x_16gf",
+    model_fn=tm.regnet_x_16gf,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_resnext50_32x4d",
+    model_fn=tm.resnext50_32x4d,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_shufflenet_v2_x0_5",
+    model_fn=tm.shufflenet_v2_x0_5,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)
+model_zoo.register(
+    name="torchvision_squeezenet1_0",
+    model_fn=tm.squeezenet1_0,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-model_zoo.register(name='torchvision_vgg11',
-                   model_fn=tm.vgg11,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
-model_zoo.register(name='torchvision_wide_resnet50_2',
-                   model_fn=tm.wide_resnet50_2,
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn)
+model_zoo.register(
+    name="torchvision_vgg11", model_fn=tm.vgg11, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn
+)
+model_zoo.register(
+    name="torchvision_wide_resnet50_2",
+    model_fn=tm.wide_resnet50_2,
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+)

-if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
-    model_zoo.register(name='torchvision_vit_b_16',
-                       model_fn=tm.vit_b_16,
-                       data_gen_fn=data_gen_fn,
-                       output_transform_fn=output_transform_fn)
-    model_zoo.register(name='torchvision_convnext_base',
-                       model_fn=tm.convnext_base,
-                       data_gen_fn=data_gen_fn,
-                       output_transform_fn=output_transform_fn,
-                       model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
-
-if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
+if version.parse(torchvision.__version__) >= version.parse("0.12.0"):
     model_zoo.register(
-        name='torchvision_swin_s',
+        name="torchvision_vit_b_16",
+        model_fn=tm.vit_b_16,
+        data_gen_fn=data_gen_fn,
+        output_transform_fn=output_transform_fn,
+    )
+    model_zoo.register(
+        name="torchvision_convnext_base",
+        model_fn=tm.convnext_base,
+        data_gen_fn=data_gen_fn,
+        output_transform_fn=output_transform_fn,
+        model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
+    )
+
+if version.parse(torchvision.__version__) >= version.parse("0.13.0"):
+    model_zoo.register(
+        name="torchvision_swin_s",
         model_fn=swin_s,
         data_gen_fn=data_gen_fn,
         output_transform_fn=swin_s_output_output_transform_fn,
     )
-    model_zoo.register(name='torchvision_efficientnet_v2_s',
-                       model_fn=tm.efficientnet_v2_s,
-                       data_gen_fn=data_gen_fn,
-                       output_transform_fn=output_transform_fn,
-                       model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
+    model_zoo.register(
+        name="torchvision_efficientnet_v2_s",
+        model_fn=tm.efficientnet_v2_s,
+        data_gen_fn=data_gen_fn,
+        output_transform_fn=output_transform_fn,
+        model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
+    )
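Editor's note: the special output transforms in this file exist because GoogLeNet and Inception v3 return namedtuples carrying auxiliary logits in training mode, and summing the fields collapses them into one tensor so every model exposes a uniform dict interface. A hedged standalone illustration (not part of the diff; the builder kwargs are ordinary torchvision options):

```python
import torch
import torchvision
import torchvision.models as tm

# With aux_logits enabled and the model in train mode, forward() returns a
# GoogLeNetOutputs namedtuple of (logits, aux_logits2, aux_logits1).
model = tm.googlenet(num_classes=10, aux_logits=True, init_weights=True)
model.train()
out = model(torch.rand(2, 3, 224, 224))

transform = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x)
print(transform(out)["output"].shape)  # torch.Size([2, 10]): main + aux logits summed
```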
@@ -19,44 +19,52 @@ def data_gen_fn():

 def data_gen_for_pretrain():
     inputs = data_gen_fn()
-    inputs['labels'] = inputs['input_ids'].clone()
-    inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
+    inputs["labels"] = inputs["input_ids"].clone()
+    inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
     return inputs


 output_transform_fn = lambda x: x

-config = transformers.AlbertConfig(embedding_size=128,
-                                   hidden_size=128,
-                                   num_hidden_layers=2,
-                                   num_attention_heads=4,
-                                   intermediate_size=256)
+config = transformers.AlbertConfig(
+    embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256
+)

-model_zoo.register(name='transformers_albert',
-                   model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_albert_for_pretraining',
-                   model_fn=lambda: transformers.AlbertForPreTraining(config),
-                   data_gen_fn=data_gen_for_pretrain,
-                   output_transform_fn=lambda x: dict(loss=x.loss),
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_albert_for_masked_lm',
-                   model_fn=lambda: transformers.AlbertForMaskedLM(config),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_albert_for_sequence_classification',
-                   model_fn=lambda: transformers.AlbertForSequenceClassification(config),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_albert_for_token_classification',
-                   model_fn=lambda: transformers.AlbertForTokenClassification(config),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="transformers_albert",
+    model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_albert_for_pretraining",
+    model_fn=lambda: transformers.AlbertForPreTraining(config),
+    data_gen_fn=data_gen_for_pretrain,
+    output_transform_fn=lambda x: dict(loss=x.loss),
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_albert_for_masked_lm",
+    model_fn=lambda: transformers.AlbertForMaskedLM(config),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_albert_for_sequence_classification",
+    model_fn=lambda: transformers.AlbertForSequenceClassification(config),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_albert_for_token_classification",
+    model_fn=lambda: transformers.AlbertForTokenClassification(config),
+    data_gen_fn=data_gen_fn,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)

 # ===============================
 # Register multi-sentence ALBERT
@@ -80,13 +88,17 @@ def data_gen_for_mcq():
     return encoding


-model_zoo.register(name='transformers_albert_for_question_answering',
-                   model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
-                   data_gen_fn=data_gen_for_qa,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_albert_for_multiple_choice',
-                   model_fn=lambda: transformers.AlbertForMultipleChoice(config),
-                   data_gen_fn=data_gen_for_mcq,
-                   output_transform_fn=output_transform_fn,
-                   model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(
+    name="transformers_albert_for_question_answering",
+    model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
+    data_gen_fn=data_gen_for_qa,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_albert_for_multiple_choice",
+    model_fn=lambda: transformers.AlbertForMultipleChoice(config),
+    data_gen_fn=data_gen_for_mcq,
+    output_transform_fn=output_transform_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
@@ -28,7 +28,7 @@ def data_gen_for_lm():
     # LM data gen
     # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
     data = data_gen()
-    data['labels'] = data['input_ids'].clone()
+    data["labels"] = data["input_ids"].clone()
     return data


@@ -36,7 +36,7 @@ def data_gen_for_pretraining():
     # pretraining data gen
     # `next_sentence_label` is the label for next sentence prediction, 0 or 1
     data = data_gen_for_lm()
-    data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
+    data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64)
     return data


@@ -44,7 +44,7 @@ def data_gen_for_sequence_classification():
     # sequence classification data gen
    # `labels` is the label for sequence classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([1], dtype=torch.int64)
+    data["labels"] = torch.tensor([1], dtype=torch.int64)
     return data


@@ -52,7 +52,7 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
     return data


@@ -67,32 +67,276 @@ def data_gen_for_mcq():
     # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
     # data = {k: v.unsqueeze(0) for k, v in encoding.items()}
     # data['labels'] = torch.tensor([0], dtype=torch.int64)
-    input_ids = torch.tensor([[[
-        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
-        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
-        1012, 102, 102
-    ],
-                              [
-                                  101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
-                                  4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
-                                  2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
-                              ]]])
-    token_type_ids = torch.tensor([[[
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-        1, 1, 1
-    ],
-                                   [
-                                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
-                                       1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
-                                   ]]])
-    attention_mask = torch.tensor([[[
-        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-        1, 1, 1
-    ],
-                                   [
-                                       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-                                       1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
-                                   ]]])
+    input_ids = torch.tensor(
+        [
+            [
+                [
+                    101,
+                    1999,
+                    3304,
+                    1010,
+                    10733,
+                    2366,
+                    1999,
+                    5337,
+                    10906,
+                    1010,
+                    2107,
+                    2004,
+                    2012,
+                    1037,
+                    4825,
+                    1010,
+                    2003,
+                    3591,
+                    4895,
+                    14540,
+                    6610,
+                    2094,
+                    1012,
+                    102,
+                    2009,
+                    2003,
+                    8828,
+                    2007,
+                    1037,
+                    9292,
+                    1998,
+                    1037,
+                    5442,
+                    1012,
+                    102,
+                    102,
+                    5442,
+                    1012,
+                    102,
+                    102,
+                ],
+                [
+                    101,
+                    1999,
+                    3304,
+                    1010,
+                    10733,
+                    2366,
+                    1999,
+                    5337,
+                    10906,
+                    1010,
+                    2107,
+                    2004,
+                    2012,
+                    1037,
+                    4825,
+                    1010,
+                    2003,
+                    3591,
+                    4895,
+                    14540,
+                    6610,
+                    2094,
+                    1012,
+                    102,
+                    2009,
+                    2003,
+                    8828,
+                    2096,
+                    2218,
+                    1999,
+                    1996,
+                    2192,
+                    1012,
+                    102,
+                    0,
+                    0,
+                    1012,
+                    102,
+                    0,
+                    0,
+                ],
+            ]
+        ]
+    )
+    token_type_ids = torch.tensor(
+        [
+            [
+                [
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                ],
+                [
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    0,
+                    0,
+                    1,
+                    1,
+                    0,
+                    0,
+                ],
+            ]
+        ]
+    )
+    attention_mask = torch.tensor(
+        [
+            [
+                [
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                ],
+                [
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    1,
+                    0,
+                    0,
+                    1,
+                    1,
+                    0,
+                    0,
+                ],
+            ]
+        ]
+    )
     labels = torch.tensor([0], dtype=torch.int64)

     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
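Editor's note: the hard-coded MCQ tensors above have shape `(batch, num_choices, seq_len) = (1, 2, 40)`. A small sketch of how `BertForMultipleChoice` consumes that layout (illustrative only; it uses random ids rather than the fixed ones above):

```python
import torch
import transformers

# BertForMultipleChoice flattens the choice dimension internally and scores
# each choice; `labels` holds the index of the correct choice per batch item.
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
model = transformers.BertForMultipleChoice(config)

input_ids = torch.randint(0, config.vocab_size, (1, 2, 40))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0])

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(out.logits.shape)  # torch.Size([1, 2]): one score per choice
```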
@@ -103,9 +347,9 @@ def data_gen_for_qa():
     # no need for labels and use start and end position instead
     data = data_gen()
     start_positions = torch.tensor([0], dtype=torch.int64)
-    data['start_positions'] = start_positions
+    data["start_positions"] = start_positions
     end_positions = torch.tensor([1], dtype=torch.int64)
-    data['end_positions'] = end_positions
+    data["end_positions"] = end_positions
     return data


@@ -114,69 +358,90 @@ output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
|
||||
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
|
||||
))
|
||||
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
|
||||
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.BertConfig(hidden_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=256,
|
||||
hidden_dropout_prob=0,
|
||||
attention_probs_dropout_prob=0)
|
||||
config = transformers.BertConfig(
|
||||
hidden_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=256,
|
||||
hidden_dropout_prob=0,
|
||||
attention_probs_dropout_prob=0,
|
||||
)
|
||||
|
||||
# register the BERT variants
|
||||
model_zoo.register(name='transformers_bert',
|
||||
model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_bert_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_pretraining',
|
||||
model_fn=lambda: transformers.BertForPreTraining(config),
|
||||
data_gen_fn=data_gen_for_pretraining,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_lm_head_model',
|
||||
model_fn=lambda: transformers.BertLMHeadModel(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_masked_lm',
|
||||
model_fn=lambda: transformers.BertForMaskedLM(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_sequence_classification',
|
||||
model_fn=lambda: transformers.BertForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_token_classification',
|
||||
model_fn=lambda: transformers.BertForTokenClassification(config),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_next_sentence',
|
||||
model_fn=lambda: transformers.BertForNextSentencePrediction(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_mcq',
|
||||
model_fn=lambda: transformers.BertForMultipleChoice(config),
|
||||
data_gen_fn=data_gen_for_mcq,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bert_for_question_answering',
|
||||
model_fn=lambda: transformers.BertForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_qa,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_bert",
    model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_bert_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_pretraining",
    model_fn=lambda: transformers.BertForPreTraining(config),
    data_gen_fn=data_gen_for_pretraining,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_lm_head_model",
    model_fn=lambda: transformers.BertLMHeadModel(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_masked_lm",
    model_fn=lambda: transformers.BertForMaskedLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_sequence_classification",
    model_fn=lambda: transformers.BertForSequenceClassification(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_token_classification",
    model_fn=lambda: transformers.BertForTokenClassification(config),
    data_gen_fn=data_gen_for_token_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_next_sentence",
    model_fn=lambda: transformers.BertForNextSentencePrediction(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_mcq",
    model_fn=lambda: transformers.BertForMultipleChoice(config),
    data_gen_fn=data_gen_for_mcq,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_question_answering",
    model_fn=lambda: transformers.BertForQuestionAnswering(config),
    data_gen_fn=data_gen_for_qa,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
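
Each register call above stores five callables under one key: a model factory, a data generator, an output transform, a loss function, and a ModelAttribute flag. A minimal sketch of how a test harness might consume an entry; the dict-style iteration and the tuple unpacking are assumptions for illustration, not the kit's documented API:

# hypothetical consumption loop; the real registry interface may differ
for name, (model_fn, data_gen_fn, transform_fn, entry_loss_fn, _attr) in model_zoo.items():
    model = model_fn()                     # tiny randomly-initialized model
    data = data_gen_fn()                   # one deterministic micro-batch
    outputs = transform_fn(model(**data))  # normalize the output object
    loss = entry_loss_fn(outputs)          # scalar used for backward/grad checks
    loss.backward()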
@@ -47,16 +47,20 @@ config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0

# register the blip2 variants
model_zoo.register(name='transformers_blip2',
                   model_fn=lambda: transformers.Blip2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_blip2",
    model_fn=lambda: transformers.Blip2Model(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_blip2_conditional_gerneration',
                   model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_blip2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_blip2_conditional_gerneration",
    model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
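
Zeroing every dropout probability, as the config tweaks above do, makes forward passes deterministic so loss and gradient comparisons are reproducible across runs. A self-contained torch sketch of the effect:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 8)

lossy = nn.Dropout(p=0.5).train()   # stochastic in train mode
exact = nn.Dropout(p=0.0).train()   # p=0 turns dropout into the identity

print(torch.equal(lossy(x), lossy(x)))  # almost surely False: masks are re-sampled
print(torch.equal(exact(x), x))         # True: deterministic and test-friendly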
@@ -25,7 +25,7 @@ def data_gen_for_lm():
    # LM data gen
    # for the LM task the `labels` are simply the output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    data["labels"] = data["input_ids"].clone()
    return data
@@ -33,14 +33,14 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` holds per-token class ids (0 or 1), not token ids
    data = data_gen()
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data['labels'] = torch.tensor([0], dtype=torch.int64)
    data["labels"] = torch.tensor([0], dtype=torch.int64)
    return data


@@ -54,62 +54,69 @@ def data_gen_for_question_answering():

    input_ids = torch.tensor(
        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
        dtype=torch.int64)
        dtype=torch.int64,
    )
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    start_positions = torch.tensor([1], dtype=torch.int64)
    end_positions = torch.tensor([10], dtype=torch.int64)
    return dict(input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions)
    return dict(
        input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions
    )
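
`start_positions` and `end_positions` mark an inclusive answer span over the token sequence, so the labels above declare tokens 1 through 10 as the answer. A quick check of that reading:

import torch

input_ids = torch.tensor(
    [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
    dtype=torch.int64,
)
start, end = 1, 10
answer_ids = input_ids[0, start : end + 1]
print(answer_ids.numel())  # 10 tokens fall inside the labelled span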

# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
                                                                 torch.ones_like(x.last_hidden_state))
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
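
The loss lambdas fall into two groups: bare backbones only expose hidden states, so they get an MSE loss against a tensor of ones, while task heads already return a `.loss` attribute. A self-contained sketch with stand-ins for the Hugging Face output objects:

import torch
from types import SimpleNamespace

hidden = torch.randn(2, 5, 64, requires_grad=True)
base_out = SimpleNamespace(last_hidden_state=hidden)  # mimics a BloomModel output
head_out = SimpleNamespace(loss=torch.tensor(0.7))    # mimics a task-head output

mse = loss_fn_for_bloom_model(base_out)  # differentiable w.r.t. the hidden states
mse.backward()
print(loss_fn_for_causal_lm(head_out))   # just forwards the head's own loss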

config = transformers.BloomConfig(n_layer=2,
                                  n_head=4,
                                  vocab_size=250880,
                                  hidden_dropout=0,
                                  attention_dropout=0,
                                  hidden_size=64,
                                  pad_token_id=50256)
config = transformers.BloomConfig(
    n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
)

# register the following models
model_zoo.register(name='transformers_bloom',
                   model_fn=lambda: transformers.BloomModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_bloom_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_causal_lm',
                   model_fn=lambda: transformers.BloomForCausalLM(config),
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_causal_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_sequence_classification',
                   model_fn=lambda: transformers.BloomForSequenceClassification(config),
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_token_classification',
                   model_fn=lambda: transformers.BloomForTokenClassification(config),
                   data_gen_fn=data_gen_for_token_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_question_answering',
                   model_fn=lambda: transformers.BloomForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_question_answering,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_question_answering,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_bloom",
    model_fn=lambda: transformers.BloomModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_bloom_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bloom_for_causal_lm",
    model_fn=lambda: transformers.BloomForCausalLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_causal_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bloom_for_sequence_classification",
    model_fn=lambda: transformers.BloomForSequenceClassification(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_classification,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bloom_for_token_classification",
    model_fn=lambda: transformers.BloomForTokenClassification(config),
    data_gen_fn=data_gen_for_token_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_classification,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bloom_for_question_answering",
    model_fn=lambda: transformers.BloomForQuestionAnswering(config),
    data_gen_fn=data_gen_for_question_answering,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_question_answering,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -1,5 +1,4 @@
import torch
import transformers

from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
@@ -21,8 +20,8 @@ def data_gen_for_conditional_generation():
    # conditional generation data gen
    # the `labels` are the target token ids, so `input_ids` is reused as `labels`
    data = data_gen()
    labels = data['input_ids'].clone()
    data['labels'] = labels
    labels = data["input_ids"].clone()
    data["labels"] = labels
    return data


@@ -30,29 +29,36 @@ def data_gen_for_conditional_generation():
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
                                                                   torch.ones_like(x.last_hidden_state))
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss

config = ChatGLMConfig(num_layers=2,
                       padded_vocab_size=65024,
                       hidden_size=64,
                       num_attention_heads=8,
                       rmsnorm=True,
                       original_rope=True,
                       use_cache=True,
                       torch_dtype=torch.float32)
config = ChatGLMConfig(
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=64,
    num_attention_heads=8,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
    torch_dtype=torch.float32,
)

model_zoo.register(name='transformers_chatglm',
                   model_fn=lambda: ChatGLMModel(config, empty_init=False),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_chatglm_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_chatglm",
    model_fn=lambda: ChatGLMModel(config, empty_init=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_chatglm_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name="transformers_chatglm_for_conditional_generation",
                   model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_chatglm_for_conditional_generation",
    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -27,7 +27,7 @@ def data_gen_for_lm():
    # LM data gen
    # for the LM task the `labels` are simply the output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    data["labels"] = data["input_ids"].clone()
    return data


@@ -36,9 +36,9 @@ def data_gen_for_question_answering():
    # the answer span is marked by `start_positions` and `end_positions`
    data = data_gen()
    start_positions = torch.tensor([0], dtype=torch.int64)
    data['start_positions'] = start_positions
    data["start_positions"] = start_positions
    end_positions = torch.tensor([1], dtype=torch.int64)
    data['end_positions'] = end_positions
    data["end_positions"] = end_positions
    return data


@@ -46,14 +46,14 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` holds per-token class ids (0 or 1), not token ids
    data = data_gen()
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
    data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data['labels'] = torch.tensor([1], dtype=torch.int64)
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data


@@ -62,7 +62,8 @@ def date_gen_for_double_heads():
    batch_size = 2
    input_ids = torch.tensor(
        [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
        dtype=torch.int64)
        dtype=torch.int64,
    )
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
@@ -85,58 +86,73 @@ def date_gen_for_double_heads():
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
                                                                ))
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss

config = transformers.GPT2Config(n_layer=2,
                                 n_head=4,
                                 vocab_size=50258,
                                 attn_pdrop=0,
                                 embd_pdrop=0,
                                 resid_pdrop=0,
                                 summary_first_dropout=0,
                                 hidden_dropout=0,
                                 problem_type="single_label_classification",
                                 pad_token_id=50256)
config = transformers.GPT2Config(
    n_layer=2,
    n_head=4,
    vocab_size=50258,
    attn_pdrop=0,
    embd_pdrop=0,
    resid_pdrop=0,
    summary_first_dropout=0,
    hidden_dropout=0,
    problem_type="single_label_classification",
    pad_token_id=50256,
)

config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2
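
`copy.deepcopy` gives the token-classification variant its own config object, so setting `num_labels = 2` cannot leak into the shared base config. The same pattern in isolation:

import copy

import transformers

base = transformers.GPT2Config(n_layer=2, n_head=4, num_labels=16)
variant = copy.deepcopy(base)  # an independent object, not an alias
variant.num_labels = 2         # only the copy changes

print(base.num_labels, variant.num_labels)  # 16 2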

# register the following models
model_zoo.register(name='transformers_gpt',
                   model_fn=lambda: transformers.GPT2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_gpt2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_lm',
                   model_fn=lambda: transformers.GPT2LMHeadModel(config),
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
                   model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
                   data_gen_fn=date_gen_for_double_heads,
                   output_transform_fn=output_transform_fn,
                   loss_fn=lambda x: x.loss + x.mc_loss,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
                   model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_question_answering,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
                   model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
                   data_gen_fn=data_gen_for_token_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_gpt",
    model_fn=lambda: transformers.GPT2Model(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_gpt2_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_gpt_lm",
    model_fn=lambda: transformers.GPT2LMHeadModel(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_gpt_double_heads",
    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
    data_gen_fn=date_gen_for_double_heads,
    output_transform_fn=output_transform_fn,
    loss_fn=lambda x: x.loss + x.mc_loss,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_gpt_for_question_answering",
    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
    data_gen_fn=data_gen_for_question_answering,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_gpt_for_token_classification",
    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
    data_gen_fn=data_gen_for_token_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_gpt_for_sequence_classification",
    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -4,7 +4,8 @@ import transformers
from ..registry import ModelAttribute, model_zoo

try:
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
    from transformers import LlamaConfig

    HAS_LLAMA = True
except ImportError:
    HAS_LLAMA = False
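
The try/except sets a module-level flag instead of letting an older `transformers` install crash the whole model zoo at import time; the hunk below shows the registrations living under `if HAS_LLAMA:`. The same guard pattern in general form, with an illustrative module name:

try:
    from some_optional_dep import FancyModel  # may be absent or too old
    HAS_FANCY = True
except ImportError:
    HAS_FANCY = False

if HAS_FANCY:
    # code touching FancyModel only runs when the import succeeded
    print("registering FancyModel variants")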
@@ -33,8 +34,8 @@ if HAS_LLAMA:
    # labels are needed for the causal LM
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data['input_ids'].clone()
        data['labels'] = labels
        labels = data["input_ids"].clone()
        data["labels"] = labels
        return data

    # transform the output to a dict
@@ -45,12 +46,14 @@ if HAS_LLAMA:
    loss_fn_for_casual_lm = lambda output: output.loss
    loss_fn_for_seq_classification = lambda output: output.logits.mean()

    config = LlamaConfig(num_hidden_layers=4,
                         hidden_size=128,
                         intermediate_size=256,
                         num_attention_heads=4,
                         max_position_embeddings=128,
                         num_labels=16)
    config = LlamaConfig(
        num_hidden_layers=4,
        hidden_size=128,
        intermediate_size=256,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
@@ -59,21 +62,27 @@ if HAS_LLAMA:
    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(name='transformers_llama',
                       model_fn=lambda: transformers.LlamaModel(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_casual_lm',
                       model_fn=lambda: transformers.LlamaForCausalLM(config),
                       data_gen_fn=data_gen_for_casual_lm,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_casual_lm,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_sequence_classification',
                       model_fn=lambda: transformers.LlamaForSequenceClassification(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_seq_classification,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(
        name="transformers_llama",
        model_fn=lambda: transformers.LlamaModel(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_casual_lm",
        model_fn=lambda: transformers.LlamaForCausalLM(config),
        data_gen_fn=data_gen_for_casual_lm,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_casual_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
        name="transformers_llama_for_sequence_classification",
        model_fn=lambda: transformers.LlamaForSequenceClassification(config),
        data_gen_fn=data_gen,
        output_transform_fn=output_transform_fn,
        loss_fn=loss_fn_for_seq_classification,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
@@ -20,8 +20,8 @@ def data_gen_for_causal_lm():
    # LM data gen
    # for the LM task the `labels` are simply the output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    labels = data['input_ids'].clone()
    data['labels'] = labels
    labels = data["input_ids"].clone()
    data["labels"] = labels
    return data


@@ -29,8 +29,8 @@ def data_gen_for_sequence_classification():
    # sequence classification data gen
    # `labels` is a single class id per sequence, not a token id
    data = data_gen()
    labels = data['input_ids'].clone()
    data['labels'] = torch.tensor([1])
    data["input_ids"].clone()  # NOTE: leftover no-op; the cloned tensor is never used
    data["labels"] = torch.tensor([1])
    return data


@@ -38,14 +38,15 @@ def data_gen_for_question_answering():
    # question answering data gen
    # the answer span is given by `start_positions` and `end_positions`, not by `labels`
    data = data_gen()
    data['start_positions'] = torch.tensor([0])
    data['end_positions'] = torch.tensor([1])
    data["start_positions"] = torch.tensor([0])
    data["end_positions"] = torch.tensor([1])
    return data


output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)
                                                               )
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_lm = lambda x: x.loss
config = transformers.OPTConfig(
    hidden_size=128,
@@ -57,24 +58,30 @@ config = transformers.OPTConfig(
# register the following models
# transformers.OPTModel,
# transformers.OPTForCausalLM,
model_zoo.register(name='transformers_opt',
                   model_fn=lambda: transformers.OPTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_opt_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_causal_lm',
                   model_fn=lambda: transformers.OPTForCausalLM(config),
                   data_gen_fn=data_gen_for_causal_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_question_answering',
                   model_fn=lambda: transformers.OPTForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_question_answering,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_lm,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_opt",
    model_fn=lambda: transformers.OPTModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_opt_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_opt_for_causal_lm",
    model_fn=lambda: transformers.OPTForCausalLM(config),
    data_gen_fn=data_gen_for_causal_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_opt_for_question_answering",
    model_fn=lambda: transformers.OPTForQuestionAnswering(config),
    data_gen_fn=data_gen_for_question_answering,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# TODO: the loss and gradient checks in the test are failing; to be fixed
# model_zoo.register(name='transformers_opt_for_sequence_classification',
@@ -28,10 +28,12 @@ def data_gen():
    original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64)
    reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64)
    input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64)
    return dict(pixel_values=pixel_values,
                original_sizes=original_sizes,
                reshaped_input_sizes=reshaped_input_sizes,
                input_points=input_points)
    return dict(
        pixel_values=pixel_values,
        original_sizes=original_sizes,
        reshaped_input_sizes=reshaped_input_sizes,
        input_points=input_points,
    )
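
The `input_points` tensor is deliberately 4-D: SAM prompts are shaped (batch, point_batch, points_per_image, 2) with (x, y) pixel coordinates, so even a single point keeps four nesting levels. A shape check, with descriptive names for the axes:

import torch

input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64)
batch, point_batch, points_per_image, xy = input_points.shape
print(batch, point_batch, points_per_image, xy)  # 1 1 1 2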

# define output transform function
@@ -44,9 +46,11 @@ config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2

# register the SAM variants
model_zoo.register(name='transformers_sam',
                   model_fn=lambda: transformers.SamModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_sam",
    model_fn=lambda: transformers.SamModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -27,7 +27,7 @@ def data_gen_for_conditional_generation():
    # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
    data = data_gen_for_encoder_only()
    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
    data['labels'] = labels
    data["labels"] = labels
    return data


@@ -36,7 +36,7 @@ def data_gen_for_t5_model():
    # decoder_input_ids = model._shift_right(input_ids)
    data = data_gen_for_encoder_only()
    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
    data['decoder_input_ids'] = decoder_input_ids
    data["decoder_input_ids"] = decoder_input_ids
    return data
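
The hard-coded `decoder_input_ids` reproduce what the commented-out `model._shift_right(input_ids)` would compute for T5: prepend the decoder start token (id 0 for T5) and drop the last token. A minimal reimplementation under that assumption:

import torch

def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int = 0) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()  # move every token one slot right
    shifted[:, 0] = decoder_start_token_id      # seed the decoder with the start id
    return shifted

ids = torch.tensor([[13959, 1566, 12, 2968]])
print(shift_right(ids))  # tensor([[    0, 13959,  1566,    12]])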

@@ -55,21 +55,27 @@ config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decode
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
                   model_fn=lambda: transformers.T5Model(config),
                   data_gen_fn=data_gen_for_t5_model,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_t5_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
                   model_fn=lambda: transformers.T5ForConditionalGeneration(config),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_conditional_generation,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
                   model_fn=lambda: transformers.T5EncoderModel(config),
                   data_gen_fn=data_gen_for_encoder_only,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_encoder_only,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_t5",
    model_fn=lambda: transformers.T5Model(config),
    data_gen_fn=data_gen_for_t5_model,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_t5_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_t5_for_conditional_generation",
    model_fn=lambda: transformers.T5ForConditionalGeneration(config),
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_conditional_generation,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_t5_encoder_model",
    model_fn=lambda: transformers.T5EncoderModel(config),
    data_gen_fn=data_gen_for_encoder_only,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_encoder_only,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -18,15 +18,15 @@ def data_gen():

def data_gen_for_image_classification():
    data = data_gen()
    data['labels'] = torch.tensor([0])
    data["labels"] = torch.tensor([0])
    return data


def data_gen_for_masked_image_modeling():
    data = data_gen()
    num_patches = (config.image_size // config.patch_size)**2
    num_patches = (config.image_size // config.patch_size) ** 2
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    data['bool_masked_pos'] = bool_masked_pos
    data["bool_masked_pos"] = bool_masked_pos
    return data
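
`num_patches` is the squared number of patches per image side, and `bool_masked_pos` must provide exactly one boolean per patch. With the ViTConfig defaults of a 224-pixel image and 16-pixel patches the arithmetic works out to 196:

image_size, patch_size = 224, 16              # transformers.ViTConfig defaults
num_patches = (image_size // patch_size) ** 2
print(num_patches)                            # 14 ** 2 == 196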

@@ -42,23 +42,29 @@ loss_fn_for_masked_image_modeling = lambda x: x.loss
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
                   model_fn=lambda: transformers.ViTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_vit_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_vit",
    model_fn=lambda: transformers.ViTModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_vit_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_vit_for_masked_image_modeling',
                   model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
                   data_gen_fn=data_gen_for_masked_image_modeling,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_masked_image_modeling,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_vit_for_masked_image_modeling",
    model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
    data_gen_fn=data_gen_for_masked_image_modeling,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_masked_image_modeling,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_vit_for_image_classification',
                   model_fn=lambda: transformers.ViTForImageClassification(config),
                   data_gen_fn=data_gen_for_image_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_image_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_vit_for_image_classification",
    model_fn=lambda: transformers.ViTForImageClassification(config),
    data_gen_fn=data_gen_for_image_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_image_classification,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -33,7 +33,7 @@ def data_gen_for_conditional_generation():
    # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
    # only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    data = data_gen()
    data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64)
    data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64)
    return data


@@ -44,8 +44,8 @@ def data_gen_for_audio_classification():
    # `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    # `WhisperForAudioClassification` does not need `decoder_input_ids`
    data = data_gen()
    data.pop('decoder_input_ids')
    data['labels'] = torch.tensor([1], dtype=torch.int64)
    data.pop("decoder_input_ids")
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data


@@ -69,23 +69,29 @@ config = transformers.WhisperConfig(
)

# register the Whisper variants
model_zoo.register(name='transformers_whisper',
                   model_fn=lambda: transformers.WhisperModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_whisper",
    model_fn=lambda: transformers.WhisperModel(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_whisper_for_conditional_generation',
                   model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_attr,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_whisper_for_conditional_generation",
    model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_attr,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_whisper_for_audio_classification',
                   model_fn=lambda: transformers.WhisperForAudioClassification(config),
                   data_gen_fn=data_gen_for_audio_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_attr,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_whisper_for_audio_classification",
    model_fn=lambda: transformers.WhisperForAudioClassification(config),
    data_gen_fn=data_gen_for_audio_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_attr,
    model_attribute=ModelAttribute(has_control_flow=True),
)