Migrated project

zbian
2021-10-28 18:21:23 +02:00
parent 2ebaefc542
commit 404ecbdcc6
409 changed files with 35853 additions and 0 deletions

@@ -0,0 +1,3 @@
from .base_model import BaseModel
from .vanilla_resnet import VanillaResNet
from .vision_transformer import *

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod

import torch.nn as nn

from colossalai.builder import build_layer


class BaseModel(nn.Module, ABC):
def __init__(self):
super(BaseModel, self).__init__()
self.layers = nn.ModuleList()
        self.layers_cfg = []

    def build_from_cfg(self, start=None, end=None):
        assert hasattr(self, 'layers_cfg'), \
            'Cannot find the attribute layers_cfg in this module, please check ' \
            'the spelling and make sure it has been initialized'
        if start is None:
            start = 0
        if end is None:
            end = len(self.layers_cfg)
        # build and register every configured layer in [start, end)
        for cfg in self.layers_cfg[start:end]:
            layer = build_layer(cfg)
            self.layers.append(layer)

    @abstractmethod
    def init_weights(self):
        pass

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Override this method to customize the state dict used when
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)
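
To make the config-driven pattern above concrete, a minimal sketch of a subclass. The TinyModel name is invented for illustration; 'Linear' and 'ReLU' are assumed to be layer types known to build_layer, as the ResNet config later in this commit passes the same type names:

class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        # each dict is handed to colossalai.builder.build_layer
        self.layers_cfg = [
            dict(type='Linear', in_features=32, out_features=16),
            dict(type='ReLU', inplace=True),
            dict(type='Linear', in_features=16, out_features=4),
        ]
        self.build_from_cfg()  # populates self.layers from layers_cfg

    def init_weights(self):
        pass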

@@ -0,0 +1,3 @@
from .resnet import VanillaResNet
__all__ = ['VanillaResNet']

@@ -0,0 +1,163 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from colossalai.registry import LAYERS, MODELS

from ..base_model import BaseModel


@MODELS.register_module
class VanillaResNet(BaseModel):
"""ResNet from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
"""
def __init__(
self,
num_cls: int,
block_type: str,
layers: List[int],
norm_layer_type: str = 'BatchNorm2d',
in_channels: int = 3,
groups: int = 1,
width_per_group: int = 64,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[List[bool]] = None,
dilations=(1, 1, 1, 1)
) -> None:
super().__init__()
self.inplanes = 64
self.zero_init_residual = zero_init_residual
self.blocks = layers
self.block_expansion = LAYERS.get_module(block_type).expansion
self.dilations = dilations
self.reslayer_common_cfg = dict(
type='ResLayer',
block_type=block_type,
norm_layer_type=norm_layer_type,
groups=groups,
base_width=width_per_group
)
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.layers_cfg = [
# conv1
dict(type='Conv2d',
in_channels=in_channels,
out_channels=self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False),
# bn1
dict(
type=norm_layer_type,
num_features=self.inplanes
),
# relu
dict(
type='ReLU',
inplace=True
),
# maxpool
dict(
type='MaxPool2d',
kernel_size=3,
stride=2,
padding=1
),
# layer 1
dict(
inplanes=self.inplanes,
planes=64,
blocks=self.blocks[0],
dilation=self.dilations[0],
**self.reslayer_common_cfg
),
# layer 2
dict(
inplanes=64 * self.block_expansion,
planes=128,
blocks=self.blocks[1],
stride=2,
dilate=replace_stride_with_dilation[0],
dilation=self.dilations[1],
**self.reslayer_common_cfg
),
# layer 3
dict(
inplanes=128 * self.block_expansion,
planes=256,
                blocks=self.blocks[2],
stride=2,
dilate=replace_stride_with_dilation[1],
dilation=self.dilations[2],
**self.reslayer_common_cfg
),
# layer 4
dict(
inplanes=256 * self.block_expansion,
planes=512,
                blocks=self.blocks[3],
                stride=2,
dilate=replace_stride_with_dilation[2],
dilation=self.dilations[3],
**self.reslayer_common_cfg
),
# avg pool
dict(
type='AdaptiveAvgPool2d',
output_size=(1, 1)
),
# flatten
dict(
type='LambdaWrapper',
func=lambda mod, x: torch.flatten(x, 1)
),
# linear
dict(
type='Linear',
in_features=512 * self.block_expansion,
out_features=num_cls
)
        ]

    def forward(self, x: Tensor):
        for layer in self.layers:
            x = layer(x)
        # note the trailing comma: the output is returned as a 1-tuple
        return x,

    def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if self.zero_init_residual:
for m in self.modules():
                if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
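
For illustration, a rough sketch of how this model might be used, assuming 'ResNetBasicBlock' is registered under LAYERS as init_weights above implies; the [2, 2, 2, 2] block layout is the standard ResNet-18 choice, not something fixed by this commit:

import torch

model = VanillaResNet(num_cls=10,
                      block_type='ResNetBasicBlock',
                      layers=[2, 2, 2, 2])
model.build_from_cfg()   # materialize self.layers from layers_cfg
model.init_weights()
out, = model(torch.randn(4, 3, 224, 224))  # forward returns a 1-tuple
print(out.shape)                           # torch.Size([4, 10])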

@@ -0,0 +1,3 @@
from .vision_transformer import VisionTransformerFromConfig
__all__ = ['VisionTransformerFromConfig']

@@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from typing import Optional

import torch

from colossalai.registry import MODELS
from ..base_model import BaseModel

@MODELS.register_module
class VisionTransformerFromConfig(BaseModel):
"""Vision Transformer from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/pdf/2010.11929>`_.
"""
def __init__(self,
embedding_cfg: dict,
norm_cfg: dict,
block_cfg: dict,
                 head_cfg: dict,
                 token_fusion_cfg: Optional[dict] = None,
                 embed_dim=768,
                 depth=12,
                 drop_path_rate=0.,
                 tensor_splitting_cfg: Optional[dict] = None):
super().__init__()
self.embed_dim = embed_dim
self.num_tokens = 1
self.tensor_splitting_cfg = tensor_splitting_cfg
        # stochastic depth decay rule: linearly increasing drop-path rates
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
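        # e.g. drop_path_rate=0.1 with depth=12 yields twelve evenly
        # spaced rates from 0.0 to 0.1, one per transformer block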
if token_fusion_cfg is None:
token_fusion_cfg = []
else:
token_fusion_cfg = [token_fusion_cfg]
self.layers_cfg = [
embedding_cfg,
# input tensor splitting
*self._generate_tensor_splitting_cfg(),
*token_fusion_cfg,
# blocks
*self._generate_block_cfg(
dpr=dpr, block_cfg=block_cfg, depth=depth),
# norm
norm_cfg,
# head
head_cfg
        ]

    def _fuse_tokens(self, x):
        # prepend the class token to the patch sequence;
        # assumes self.cls_token has been registered on this module
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        return x

    def _generate_block_cfg(self, dpr, depth, block_cfg):
        blocks_cfg = []
        for i in range(depth):
            # deep-copy so that each block gets its own droppath_cfg;
            # a shallow copy would share the nested dict and leave every
            # block with the last drop-path rate
            _cfg = copy.deepcopy(block_cfg)
            _cfg['droppath_cfg']['drop_path'] = dpr[i]
            blocks_cfg.append(_cfg)
        return blocks_cfg

    def _generate_tensor_splitting_cfg(self):
        if self.tensor_splitting_cfg:
            return [self.tensor_splitting_cfg]
        else:
            return []

    def forward(self, x):  # e.g. input of shape [512, 3, 32, 32]
        for layer in self.layers:
            if isinstance(x, tuple):
                x = layer(*x)
            else:
                x = layer(x)
        return x  # e.g. output of shape [256, 5]

    def init_weights(self):
        # TODO: add init weights
        pass
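
Finally, a rough sketch of the configs this class expects. Only the structure is implied by the code above: block_cfg must carry a nested droppath_cfg dict, which _generate_block_cfg fills with per-block drop_path rates. Every type name below is a hypothetical placeholder, not taken from this commit:

model = VisionTransformerFromConfig(
    embedding_cfg=dict(type='ViTPatchEmbedding',  # hypothetical layer name
                       img_size=32, patch_size=4, embed_dim=768),
    norm_cfg=dict(type='LayerNorm', normalized_shape=768),
    block_cfg=dict(type='ViTBlock',               # hypothetical layer name
                   droppath_cfg=dict(type='ViTDropPath')),
    head_cfg=dict(type='Linear', in_features=768, out_features=10),
    embed_dim=768,
    depth=12,
    drop_path_rate=0.1)
model.build_from_cfg()  # inherited from BaseModel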