Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support pipeline base vit model (#4284)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* support base vit pipeline
* support vit downstream model
* fix vit shard test
* modify hidden states return type
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
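The registrations added in the new file below are what the shardformer tests iterate over. As a rough sketch of how such entries are typically consumed — the get_sub_registry call and the tuple layout of each entry are assumptions based on how other model-zoo tests in tests/kit are written, not part of this diff:

# Sketch only: iterate the ViT entries registered in the file added below.
# Assumes the zoo exposes model_zoo.get_sub_registry and stores each entry as
# (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute).
from tests.kit.model_zoo import model_zoo

sub_registry = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
    model = model_fn()                    # ViTModel / ViTForMaskedImageModeling / ViTForImageClassification
    data = data_gen_fn()                  # synthetic pixel_values (+ labels or bool_masked_pos)
    outputs = output_transform_fn(model(**data))
    loss = loss_fn(outputs)               # scalar, so backward() can be verified
    loss.backward()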
tests/kit/model_zoo/transformers/vit.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register ViT
# ===============================

config = transformers.ViTConfig(
    num_hidden_layers=4,
    # hidden_size=128,
    # intermediate_size=256,
    num_attention_heads=4)


# define data gen function
def data_gen():
    pixel_values = torch.randn(1, 3, 224, 224)
    return dict(pixel_values=pixel_values)


def data_gen_for_image_classification():
    data = data_gen()
    data['labels'] = torch.tensor([0])
    return data


def data_gen_for_masked_image_modeling():
    data = data_gen()
    num_patches = (config.image_size // config.patch_size)**2
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    data['bool_masked_pos'] = bool_masked_pos
    return data


# define output transform function
output_transform_fn = lambda x: x

# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss

# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
                   model_fn=lambda: transformers.ViTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_vit_model,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_masked_image_modeling',
                   model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
                   data_gen_fn=data_gen_for_masked_image_modeling,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_masked_image_modeling,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_vit_for_image_classification',
                   model_fn=lambda: transformers.ViTForImageClassification(config),
                   data_gen_fn=data_gen_for_image_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_image_classification,
                   model_attribute=ModelAttribute(has_control_flow=True))
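For quick verification outside the registry, an entry can also be exercised directly. A minimal sketch (not part of the commit, using only torch and transformers) for the image-classification case:

import torch
import transformers

# Sketch only: build the same small ViT classifier, feed it the synthetic batch
# that data_gen_for_image_classification produces, and reduce the output to a
# scalar loss the way loss_fn_for_image_classification does above.
config = transformers.ViTConfig(num_hidden_layers=4, num_attention_heads=4)
model = transformers.ViTForImageClassification(config)

data = dict(pixel_values=torch.randn(1, 3, 224, 224), labels=torch.tensor([0]))
outputs = model(**data)
loss = outputs.logits.mean()   # mirrors loss_fn_for_image_classification
loss.backward()                # the loss is differentiable end to end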