mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[testing] add beit model for unit testings (#2196)
* [testing] add beit model * [beit] fix bugs * [beit] fix bugs * [testing] fix bugs
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from . import (
|
||||
beit,
|
||||
bert,
|
||||
gpt2,
|
||||
hanging_param_model,
|
||||
@@ -14,5 +15,5 @@ from . import albert # isort:skip
|
||||
|
||||
__all__ = [
|
||||
'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet',
|
||||
'simple_net', 'run_fwd_bwd', 'albert'
|
||||
'simple_net', 'run_fwd_bwd', 'albert', 'beit'
|
||||
]
|
||||
|
42
tests/components_to_test/beit.py
Normal file
42
tests/components_to_test/beit.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from timm.models.beit import Beit
|
||||
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .registry import non_distributed_component_funcs
|
||||
from .utils.dummy_data_generator import DummyDataGenerator
|
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
img_size = 64
|
||||
num_channel = 3
|
||||
num_class = 10
|
||||
batch_size = 4
|
||||
|
||||
def generate(self):
|
||||
data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size,
|
||||
DummyDataLoader.img_size),
|
||||
device=get_current_device())
|
||||
label = torch.randint(low=0,
|
||||
high=DummyDataLoader.num_class,
|
||||
size=(DummyDataLoader.batch_size,),
|
||||
device=get_current_device())
|
||||
return data, label
|
||||
|
||||
|
||||
@non_distributed_component_funcs.register(name='beit')
|
||||
def get_training_components():
|
||||
|
||||
def model_buider(checkpoint=False):
|
||||
model = Beit(img_size=DummyDataLoader.img_size,
|
||||
num_classes=DummyDataLoader.num_class,
|
||||
embed_dim=32,
|
||||
depth=2,
|
||||
num_heads=4)
|
||||
return model
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model_buider, trainloader, testloader, torch.optim.Adam, criterion
|
Reference in New Issue
Block a user