Mirror of https://github.com/hpcaitech/ColossalAI.git
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests
* fixed 2.5d runtime issue
* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
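At the API level, the change in this file is threefold: `colossalai.nn` layers no longer take a per-layer `tensor_parallel` argument, dropout moves from `torch.nn` to `col_nn` (the explicit `seed(ParallelMode.TENSOR)` wrappers around it disappear, so the `col_nn` version presumably manages the tensor-parallel RNG itself), and the hand-rolled activation-checkpoint plumbing is replaced by a `CheckpointModule` base class. A condensed before/after distilled from the hunks below (illustrative only; it assumes colossalai is installed and a launch context has been initialized):

from colossalai import nn as col_nn

dim, p = 768, 0.1

# before this commit: the caller picked the parallel variant per layer,
# and dropout was torch.nn's, invoked under `with seed(ParallelMode.TENSOR):`
#   qkv  = col_nn.Linear(dim, 3 * dim, bias=True,
#                        tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel)
#   drop = nn.Dropout(p)

# after this commit: the layer works out its own parallel mode; dropout comes from col_nn
qkv = col_nn.Linear(dim, 3 * dim, bias=True)
drop = col_nn.Dropout(p)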
@@ -3,9 +3,8 @@ from typing import Callable
 
 import torch
 from colossalai import nn as col_nn
-from colossalai.context import ParallelMode, seed
+from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.registry import LAYERS, MODELS
-from colossalai.utils import checkpoint
 from torch import dtype, nn
 
 __all__ = [
@@ -72,8 +71,7 @@ class ViTEmbedding(nn.Module):
                  dropout: float,
                  dtype: dtype = None,
                  flatten: bool = True,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
+                 init_method: str = 'torch'):
         super().__init__()
         self.patch_embed = col_nn.PatchEmbedding(img_size,
                                                  patch_size,
@@ -81,19 +79,17 @@ class ViTEmbedding(nn.Module):
                                                  embedding_dim,
                                                  dtype=dtype,
                                                  flatten=flatten,
-                                                 tensor_parallel=tensor_parallel,
                                                  **_init_rules[init_method]['embed'])
-        self.dropout = nn.Dropout(dropout)
+        self.dropout = col_nn.Dropout(dropout)
 
     def forward(self, x):
         x = self.patch_embed(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.dropout(x)
+        x = self.dropout(x)
         return x
 
 
 @LAYERS.register_module
-class ViTSelfAttention(nn.Module):
+class ViTSelfAttention(CheckpointModule):
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -102,27 +98,17 @@ class ViTSelfAttention(nn.Module):
                  bias: bool = True,
                  dtype: dtype = None,
                  checkpoint: bool = False,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
-        super().__init__()
+                 init_method: str = 'torch'):
+        super().__init__(checkpoint)
         self.attention_head_size = dim // num_heads
-        self.checkpoint = checkpoint
-        self.tensor_parallel = tensor_parallel
-
         self.query_key_value = col_nn.Linear(dim,
                                              3 * dim,
                                              dtype=dtype,
                                              bias=bias,
-                                             tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
                                              **_init_rules[init_method]['transformer'])
-        self.attention_dropout = nn.Dropout(attention_dropout)
-        self.dense = col_nn.Linear(dim,
-                                   dim,
-                                   dtype=dtype,
-                                   bias=True,
-                                   tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
-                                   **_init_rules[init_method]['transformer'])
-        self.dropout = nn.Dropout(dropout)
+        self.attention_dropout = col_nn.Dropout(attention_dropout)
+        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True, **_init_rules[init_method]['transformer'])
+        self.dropout = col_nn.Dropout(dropout)
         self.softmax = nn.Softmax(dim=-1)
 
     def _forward(self, x):
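For context on the `'1d_col'`/`'1d_row'` pairing the old code spelled out: in the Megatron-style 1D scheme the fused QKV projection splits its weight by output columns and the output projection by input rows, so the pair composes with no communication in between and a single all-reduce at the end. A toy single-process sketch of that identity (hypothetical sizes, two simulated ranks; not ColossalAI code):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # activations, replicated on both "ranks"
w_qkv = torch.randn(8, 16)   # column-parallel weight: output dim split across ranks
w_out = torch.randn(16, 8)   # row-parallel weight: input dim split across ranks

y0, y1 = x @ w_qkv[:, :8], x @ w_qkv[:, 8:]   # each rank computes half the columns
z = y0 @ w_out[:8] + y1 @ w_out[8:]           # each rank consumes its half; '+' plays the all-reduce

assert torch.allclose(z, x @ w_qkv @ w_out, atol=1e-4)

This also appears to explain the removed special case further down: between the two projections the activations are partitioned, so dropout needed per-rank seeds via `seed(ParallelMode.TENSOR)`, while after the row-parallel all-reduce they are replicated, which is why `'1d'` mode ran the final dropout outside that context.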
@@ -138,8 +124,7 @@ class ViTSelfAttention(nn.Module):
         x = torch.matmul(q, k.transpose(-1, -2))
         x = x / math.sqrt(self.attention_head_size)
         x = self.softmax(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.attention_dropout(x)
+        x = self.attention_dropout(x)
 
         x = torch.matmul(x, v)
         x = x.transpose(1, 2)
@@ -147,26 +132,13 @@ class ViTSelfAttention(nn.Module):
         x = x.reshape(new_context_layer_shape)
 
         x = self.dense(x)
-        if self.tensor_parallel == '1d':
-            x = self.dropout(x)
-        else:
-            with seed(ParallelMode.TENSOR):
-                x = self.dropout(x)
+        x = self.dropout(x)
 
         return x
 
-    def _checkpoint_forward(self, x):
-        return checkpoint(self._forward, x)
-
-    def forward(self, x):
-        if self.checkpoint:
-            return self._checkpoint_forward(x)
-        else:
-            return self._forward(x)
-
 
 @LAYERS.register_module
-class ViTMLP(nn.Module):
+class ViTMLP(CheckpointModule):
     def __init__(self,
                  dim: int,
                  mlp_ratio: int,
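`ViTSelfAttention` and `ViTMLP` previously duplicated the same `_checkpoint_forward`/`forward` dispatch; subclassing `CheckpointModule` removes it. The real `colossalai.nn.layer.utils.CheckpointModule` is not part of this diff; a minimal sketch of the base class, assuming it factors out exactly the deleted boilerplate (here using `torch.utils.checkpoint` in place of `colossalai.utils.checkpoint`), could look like:

import torch
from torch.utils.checkpoint import checkpoint


class CheckpointModule(torch.nn.Module):
    """Hypothetical sketch, not the actual ColossalAI implementation:
    hoists the activation-checkpoint dispatch each layer wrote by hand."""

    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self.checkpoint = checkpoint

    def _forward(self, *args):
        raise NotImplementedError  # subclasses provide the real computation

    def forward(self, *args):
        if self.checkpoint:
            # trade compute for memory: re-run _forward during the backward pass
            return checkpoint(self._forward, *args)
        return self._forward(*args)

Subclasses then only implement `_forward` and pass the flag up through `super().__init__(checkpoint)`, as the surrounding hunks show.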
@@ -175,50 +147,30 @@ class ViTMLP(nn.Module):
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
-        super().__init__()
-        self.checkpoint = checkpoint
-        self.tensor_parallel = tensor_parallel
-
+                 init_method: str = 'torch'):
+        super().__init__(checkpoint)
         self.dense_1 = col_nn.Linear(dim,
                                      mlp_ratio * dim,
                                      dtype=dtype,
                                      bias=bias,
-                                     tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
                                      **_init_rules[init_method]['transformer'])
         self.activation = activation
+        self.dropout_1 = col_nn.Dropout(dropout)
         self.dense_2 = col_nn.Linear(mlp_ratio * dim,
                                      dim,
                                      dtype=dtype,
                                      bias=bias,
-                                     tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
                                      **_init_rules[init_method]['transformer'])
-        self.dropout = nn.Dropout(dropout)
+        self.dropout_2 = col_nn.Dropout(dropout)
 
     def _forward(self, x):
         x = self.dense_1(x)
         x = self.activation(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.dropout(x)
+        x = self.dropout_1(x)
         x = self.dense_2(x)
-        if self.tensor_parallel == '1d':
-            x = self.dropout(x)
-        else:
-            with seed(ParallelMode.TENSOR):
-                x = self.dropout(x)
-
+        x = self.dropout_2(x)
         return x
 
-    def _checkpoint_forward(self, x):
-        return checkpoint(self._forward, x)
-
-    def forward(self, x):
-        if self.checkpoint:
-            return self._checkpoint_forward(x)
-        else:
-            return self._forward(x)
-
 
 @LAYERS.register_module
 class ViTHead(nn.Module):
@@ -228,19 +180,14 @@ class ViTHead(nn.Module):
                  representation_size: int = None,
                  dtype: dtype = None,
                  bias: bool = True,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
+                 init_method: str = 'torch'):
         super().__init__()
         if representation_size:
-            tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
-            if tensor_parallel == '1d':
-                tensor_parallel_kwargs['gather_output'] = True
             self.representation = col_nn.Linear(dim,
                                                 representation_size,
                                                 bias=bias,
                                                 dtype=dtype,
-                                                **_init_rules[init_method]['head'],
-                                                **tensor_parallel_kwargs)
+                                                **_init_rules[init_method]['head'])
         else:
             self.representation = None
             representation_size = dim
@@ -249,7 +196,6 @@ class ViTHead(nn.Module):
                                   num_classes,
                                   dtype=dtype,
                                   bias=bias,
-                                  tensor_parallel=tensor_parallel,
                                   **_init_rules[init_method]['head'])
 
     def forward(self, x):
@@ -273,10 +219,9 @@ class ViTBlock(nn.Module):
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
+                 init_method: str = 'torch'):
         super().__init__()
-        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
+        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
         self.attn = ViTSelfAttention(dim=dim,
                                      num_heads=num_heads,
                                      attention_dropout=attention_dropout,
@@ -284,10 +229,9 @@ class ViTBlock(nn.Module):
                                      bias=bias,
                                      dtype=dtype,
                                      checkpoint=checkpoint,
-                                     init_method=init_method,
-                                     tensor_parallel=tensor_parallel)
+                                     init_method=init_method)
         self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
+        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
         self.mlp = ViTMLP(dim=dim,
                           mlp_ratio=mlp_ratio,
                           activation=activation,
@@ -295,8 +239,7 @@ class ViTBlock(nn.Module):
                           dtype=dtype,
                           bias=bias,
                           checkpoint=checkpoint,
-                          init_method=init_method,
-                          tensor_parallel=tensor_parallel)
+                          init_method=init_method)
 
     def forward(self, x):
         x = x + self.drop_path(self.attn(self.norm1(x)))
@@ -323,20 +266,16 @@ class VisionTransformer(nn.Module):
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False,
-                 init_method: str = 'torch',
-                 tensor_parallel: str = None):
+                 init_method: str = 'torch'):
         super().__init__()
 
-        embed = ViTEmbedding(
-            img_size=img_size,
-            patch_size=patch_size,
-            in_chans=in_chans,
-            embedding_dim=dim,
-            dropout=dropout,
-            dtype=dtype,
-            init_method=init_method,
-            tensor_parallel=tensor_parallel,
-        )
+        embed = ViTEmbedding(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=in_chans,
+                             embedding_dim=dim,
+                             dropout=dropout,
+                             dtype=dtype,
+                             init_method=init_method)
 
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
@@ -353,26 +292,17 @@ class VisionTransformer(nn.Module):
                 bias=bias,
                 checkpoint=checkpoint,
                 init_method=init_method,
-                tensor_parallel=tensor_parallel,
             ) for i in range(depth)
         ]
 
-        norm = col_nn.LayerNorm(
-            normalized_shape=dim,
-            eps=1e-6,
-            dtype=dtype,
-            tensor_parallel=tensor_parallel,
-        )
+        norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
 
-        head = ViTHead(
-            dim=dim,
-            num_classes=num_classes,
-            representation_size=representation_size,
-            dtype=dtype,
-            bias=bias,
-            init_method=init_method,
-            tensor_parallel=tensor_parallel,
-        )
+        head = ViTHead(dim=dim,
+                       num_classes=num_classes,
+                       representation_size=representation_size,
+                       dtype=dtype,
+                       bias=bias,
+                       init_method=init_method)
 
         self.layers = nn.Sequential(
             embed,