mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
Migrated project
This commit is contained in:
2
model_zoo/__init__.py
Normal file
2
model_zoo/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .vit import *
|
||||
from .mlp_mixer import *
|
0
model_zoo/bert/parallel_1d/.init
Normal file
0
model_zoo/bert/parallel_1d/.init
Normal file
0
model_zoo/bert/parallel_2d/.init
Normal file
0
model_zoo/bert/parallel_2d/.init
Normal file
0
model_zoo/bert/parallel_2p5d/.init
Normal file
0
model_zoo/bert/parallel_2p5d/.init
Normal file
0
model_zoo/bert/parallel_3d/.init
Normal file
0
model_zoo/bert/parallel_3d/.init
Normal file
1
model_zoo/mlp_mixer/__init__.py
Normal file
1
model_zoo/mlp_mixer/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .parallel_3d import *
|
0
model_zoo/mlp_mixer/parallel_1d/.init
Normal file
0
model_zoo/mlp_mixer/parallel_1d/.init
Normal file
0
model_zoo/mlp_mixer/parallel_2d/.init
Normal file
0
model_zoo/mlp_mixer/parallel_2d/.init
Normal file
0
model_zoo/mlp_mixer/parallel_2p5d/.init
Normal file
0
model_zoo/mlp_mixer/parallel_2p5d/.init
Normal file
1
model_zoo/mlp_mixer/parallel_3d/__init__.py
Normal file
1
model_zoo/mlp_mixer/parallel_3d/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .mlp_mixer import *
|
63
model_zoo/mlp_mixer/parallel_3d/mlp_mixer.py
Normal file
63
model_zoo/mlp_mixer/parallel_3d/mlp_mixer.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# modified from https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py
|
||||
from functools import partial
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.registry import MODELS
|
||||
from torch import nn
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_depth_from_env
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
__all__ = [
|
||||
'MLPMixer',
|
||||
]
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn, depth_3d):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = col_nn.LayerNorm3D(
|
||||
dim, depth_3d, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
def FeedForward(dim, depth_3d, expansion_factor=4, dropout=0., dense=None):
|
||||
if dense is None:
|
||||
dense = partial(col_nn.Linear3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
return nn.Sequential(
|
||||
dense(dim, dim * expansion_factor),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
dense(dim * expansion_factor, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def MLPMixer(image_size, channels, patch_size, dim, depth, num_classes, expansion_factor=4, dropout=0.):
|
||||
assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
depth_3d = get_depth_from_env()
|
||||
linear = partial(col_nn.Linear3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
norm_layer = partial(col_nn.LayerNorm3D, depth=depth_3d, input_parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode=ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), linear
|
||||
|
||||
return nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
|
||||
p1=patch_size, p2=patch_size),
|
||||
linear((patch_size ** 2) * channels, dim),
|
||||
*[nn.Sequential(
|
||||
PreNormResidual(dim, FeedForward(
|
||||
num_patches, expansion_factor, dropout, chan_first)),
|
||||
PreNormResidual(dim, FeedForward(
|
||||
dim, expansion_factor, dropout, chan_last))
|
||||
) for _ in range(depth)],
|
||||
norm_layer(dim),
|
||||
Reduce('b n c -> b c', 'mean'),
|
||||
linear(dim, num_classes)
|
||||
)
|
2
model_zoo/vit/__init__.py
Normal file
2
model_zoo/vit/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .parallel_2d import *
|
||||
from .parallel_3d import *
|
0
model_zoo/vit/parallel_1d/.init
Normal file
0
model_zoo/vit/parallel_1d/.init
Normal file
1
model_zoo/vit/parallel_2d/__init__.py
Normal file
1
model_zoo/vit/parallel_2d/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .vit import *
|
219
model_zoo/vit/parallel_2d/vit.py
Normal file
219
model_zoo/vit/parallel_2d/vit.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai import nn as clsl_nn
|
||||
from colossalai.registry import MODELS
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = [
|
||||
'VisionTransformer2D',
|
||||
'vit_tiny_2d_patch4_32',
|
||||
'vit_tiny_2d_patch16_224',
|
||||
'vit_tiny_2d_patch16_384',
|
||||
'vit_small_2d_patch16_224',
|
||||
'vit_small_2d_patch16_384',
|
||||
'vit_small_2d_patch32_224',
|
||||
'vit_small_2d_patch32_384',
|
||||
'vit_base_2d_patch16_224',
|
||||
'vit_base_2d_patch16_384',
|
||||
'vit_base_2d_patch32_224',
|
||||
'vit_base_2d_patch32_384',
|
||||
'vit_large_2d_patch16_224',
|
||||
'vit_large_2d_patch16_384',
|
||||
'vit_large_2d_patch32_224',
|
||||
'vit_large_2d_patch32_384',
|
||||
]
|
||||
|
||||
|
||||
class ViTBlock2D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: int = 4,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: str = 'gelu'):
|
||||
super().__init__()
|
||||
self.norm1 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
|
||||
self.attn = clsl_nn.ViTSelfAttention2D(dim, num_heads, attn_drop, drop)
|
||||
self.drop_path = clsl_nn.VanillaViTDropPath(drop_path) if drop_path > 0. \
|
||||
else nn.Identity()
|
||||
self.norm2 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
|
||||
self.mlp = clsl_nn.ViTMLP2D(dim, mlp_ratio, act_layer, drop)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.attn(self.norm1(x))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = x + self.drop_path(y)
|
||||
y = self.mlp(self.norm2(x))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = x + self.drop_path(y)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformer2D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: int = 4,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
act_layer: str = 'gelu'):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = clsl_nn.ViTPatchEmbedding2D(
|
||||
img_size, patch_size, embed_dim, in_chans
|
||||
)
|
||||
|
||||
self.splitter = clsl_nn.ViTInputSplitter2D()
|
||||
|
||||
self.token_fuser = clsl_nn.ViTTokenFuser2D(
|
||||
img_size, patch_size, embed_dim, drop_rate
|
||||
)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
ViTBlock2D(embed_dim, num_heads, mlp_ratio, drop_rate,
|
||||
attn_drop_rate, dpr[i], act_layer)
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = clsl_nn.LayerNorm2D(embed_dim, eps=1e-6)
|
||||
self.head = clsl_nn.ViTHead2D(self.num_features, num_classes) if num_classes > 0 \
|
||||
else nn.Identity()
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.splitter(x)
|
||||
x = self.token_fuser(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vit_model(**model_kwargs):
|
||||
model = VisionTransformer2D(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch4_32(**kwargs):
|
||||
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
|
||||
depth=6, num_heads=8, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
0
model_zoo/vit/parallel_2p5d/.init
Normal file
0
model_zoo/vit/parallel_2p5d/.init
Normal file
1
model_zoo/vit/parallel_3d/__init__.py
Normal file
1
model_zoo/vit/parallel_3d/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .vit import *
|
209
model_zoo/vit/parallel_3d/vit.py
Normal file
209
model_zoo/vit/parallel_3d/vit.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.registry import MODELS
|
||||
|
||||
__all__ = [
|
||||
'VisionTransformer3D',
|
||||
'vit_tiny_3d_patch4_32',
|
||||
'vit_tiny_3d_patch16_224',
|
||||
'vit_tiny_3d_patch16_384',
|
||||
'vit_small_3d_patch16_224',
|
||||
'vit_small_3d_patch16_384',
|
||||
'vit_small_3d_patch32_224',
|
||||
'vit_small_3d_patch32_384',
|
||||
'vit_base_3d_patch16_224',
|
||||
'vit_base_3d_patch16_384',
|
||||
'vit_base_3d_patch32_224',
|
||||
'vit_base_3d_patch32_384',
|
||||
'vit_large_3d_patch16_224',
|
||||
'vit_large_3d_patch16_384',
|
||||
'vit_large_3d_patch32_224',
|
||||
'vit_large_3d_patch32_384',
|
||||
]
|
||||
|
||||
|
||||
class ViTBlock3D(nn.Module):
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.):
|
||||
super().__init__()
|
||||
self.norm1 = col_nn.LayerNorm3D(
|
||||
dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
|
||||
self.attn = col_nn.ViTSelfAttention3D(dim, num_heads, attn_drop, drop)
|
||||
self.drop_path = col_nn.VanillaViTDropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = col_nn.LayerNorm3D(dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
|
||||
self.mlp = col_nn.ViTMLP3D(hidden_dim, 1, drop, 'gelu')
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformer3D(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
embed_dim: int = 768,
|
||||
hidden_dim: int = 3072,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = col_nn.ViTPatchEmbedding3D(
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_dim,
|
||||
drop_rate,
|
||||
)
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
ViTBlock3D(embed_dim, num_heads, hidden_dim,
|
||||
drop_rate, attn_drop_rate, dpr[i])
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = col_nn.LayerNorm3D(embed_dim, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
self.head = col_nn.ViTHead3D(hidden_dim, num_classes)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vit_model(**model_kwargs):
|
||||
model = VisionTransformer3D(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch4_32(**kwargs):
|
||||
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
|
||||
depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
Reference in New Issue
Block a user