mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
Migrated project
This commit is contained in:
3
colossalai/nn/model/__init__.py
Normal file
3
colossalai/nn/model/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base_model import BaseModel
|
||||
from .vanilla_resnet import VanillaResNet
|
||||
from .vision_transformer import *
|
38
colossalai/nn/model/base_model.py
Normal file
38
colossalai/nn/model/base_model.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.builder import build_layer
|
||||
|
||||
|
||||
class BaseModel(nn.Module, ABC):
|
||||
|
||||
def __init__(self):
|
||||
super(BaseModel, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
self.layers_cfg = []
|
||||
|
||||
def build_from_cfg(self, start=None, end=None):
|
||||
assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
|
||||
'spelling and if you have initialized this variable'
|
||||
if start is None:
|
||||
start = 0
|
||||
if end is None:
|
||||
end = len(self.layers_cfg)
|
||||
for cfg in self.layers_cfg[start: end]:
|
||||
layer = build_layer(cfg)
|
||||
self.layers.append(layer)
|
||||
|
||||
@abstractmethod
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
|
||||
keep_vars=False):
|
||||
|
||||
"""Use this function to override the state dict for
|
||||
saving checkpoints."""
|
||||
return self.state_dict(destination, prefix, keep_vars)
|
3
colossalai/nn/model/vanilla_resnet/__init__.py
Normal file
3
colossalai/nn/model/vanilla_resnet/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .resnet import VanillaResNet
|
||||
|
||||
__all__ = ['VanillaResNet']
|
163
colossalai/nn/model/vanilla_resnet/resnet.py
Normal file
163
colossalai/nn/model/vanilla_resnet/resnet.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.registry import MODELS
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VanillaResNet(BaseModel):
|
||||
"""ResNet from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_cls: int,
|
||||
block_type: str,
|
||||
layers: List[int],
|
||||
norm_layer_type: str = 'BatchNorm2d',
|
||||
in_channels: int = 3,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
zero_init_residual: bool = False,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
dilations=(1, 1, 1, 1)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.inplanes = 64
|
||||
self.zero_init_residual = zero_init_residual
|
||||
self.blocks = layers
|
||||
self.block_expansion = LAYERS.get_module(block_type).expansion
|
||||
self.dilations = dilations
|
||||
self.reslayer_common_cfg = dict(
|
||||
type='ResLayer',
|
||||
block_type=block_type,
|
||||
norm_layer_type=norm_layer_type,
|
||||
groups=groups,
|
||||
base_width=width_per_group
|
||||
)
|
||||
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
|
||||
self.layers_cfg = [
|
||||
# conv1
|
||||
dict(type='Conv2d',
|
||||
in_channels=in_channels,
|
||||
out_channels=self.inplanes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False),
|
||||
# bn1
|
||||
dict(
|
||||
type=norm_layer_type,
|
||||
num_features=self.inplanes
|
||||
),
|
||||
# relu
|
||||
dict(
|
||||
type='ReLU',
|
||||
inplace=True
|
||||
),
|
||||
# maxpool
|
||||
dict(
|
||||
type='MaxPool2d',
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1
|
||||
),
|
||||
# layer 1
|
||||
dict(
|
||||
inplanes=self.inplanes,
|
||||
planes=64,
|
||||
blocks=self.blocks[0],
|
||||
dilation=self.dilations[0],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 2
|
||||
dict(
|
||||
inplanes=64 * self.block_expansion,
|
||||
planes=128,
|
||||
blocks=self.blocks[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0],
|
||||
dilation=self.dilations[1],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 3
|
||||
dict(
|
||||
inplanes=128 * self.block_expansion,
|
||||
planes=256,
|
||||
blocks=layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1],
|
||||
dilation=self.dilations[2],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# layer 4
|
||||
dict(
|
||||
inplanes=256 * self.block_expansion,
|
||||
planes=512,
|
||||
blocks=layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2],
|
||||
dilation=self.dilations[3],
|
||||
**self.reslayer_common_cfg
|
||||
),
|
||||
# avg pool
|
||||
dict(
|
||||
type='AdaptiveAvgPool2d',
|
||||
output_size=(1, 1)
|
||||
),
|
||||
# flatten
|
||||
dict(
|
||||
type='LambdaWrapper',
|
||||
func=lambda mod, x: torch.flatten(x, 1)
|
||||
),
|
||||
# linear
|
||||
dict(
|
||||
type='Linear',
|
||||
in_features=512 * self.block_expansion,
|
||||
out_features=num_cls
|
||||
)
|
||||
]
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x,
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if self.zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
|
||||
# type: ignore[arg-type]
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
3
colossalai/nn/model/vision_transformer/__init__.py
Normal file
3
colossalai/nn/model/vision_transformer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vision_transformer import VisionTransformerFromConfig
|
||||
|
||||
__all__ = ['VisionTransformerFromConfig']
|
87
colossalai/nn/model/vision_transformer/vision_transformer.py
Normal file
87
colossalai/nn/model/vision_transformer/vision_transformer.py
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.registry import MODELS
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformerFromConfig(BaseModel):
|
||||
"""Vision Transformer from
|
||||
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/pdf/2010.11929>`_.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_cfg: dict,
|
||||
norm_cfg: dict,
|
||||
block_cfg: dict,
|
||||
head_cfg: dict,
|
||||
token_fusion_cfg: dict = None,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
drop_path_rate=0.,
|
||||
tensor_splitting_cfg: dict = None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_tokens = 1
|
||||
self.tensor_splitting_cfg = tensor_splitting_cfg
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
if token_fusion_cfg is None:
|
||||
token_fusion_cfg = []
|
||||
else:
|
||||
token_fusion_cfg = [token_fusion_cfg]
|
||||
|
||||
self.layers_cfg = [
|
||||
embedding_cfg,
|
||||
|
||||
# input tensor splitting
|
||||
*self._generate_tensor_splitting_cfg(),
|
||||
*token_fusion_cfg,
|
||||
|
||||
# blocks
|
||||
*self._generate_block_cfg(
|
||||
dpr=dpr, block_cfg=block_cfg, depth=depth),
|
||||
|
||||
# norm
|
||||
norm_cfg,
|
||||
|
||||
# head
|
||||
head_cfg
|
||||
]
|
||||
|
||||
def _fuse_tokens(self, x):
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
return x
|
||||
|
||||
def _generate_block_cfg(self, dpr, depth, block_cfg):
|
||||
blocks_cfg = []
|
||||
|
||||
for i in range(depth):
|
||||
_cfg = block_cfg.copy()
|
||||
_cfg['droppath_cfg']['drop_path'] = dpr[i]
|
||||
blocks_cfg.append(_cfg)
|
||||
|
||||
return blocks_cfg
|
||||
|
||||
def _generate_tensor_splitting_cfg(self):
|
||||
if self.tensor_splitting_cfg:
|
||||
return [self.tensor_splitting_cfg]
|
||||
else:
|
||||
return []
|
||||
|
||||
def forward(self, x): # [512, 3, 32, 32]
|
||||
for layer in self.layers:
|
||||
if isinstance(x, tuple):
|
||||
x = layer(*x)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x # [256, 5]
|
||||
|
||||
def init_weights(self):
|
||||
# TODO: add init weights
|
||||
pass
|
Reference in New Issue
Block a user