[model checkpoint] updated saving/loading for 2d layers (#595)

アマデウス 2022-04-01 16:50:34 +08:00 committed by GitHub
parent cd13b63832
commit 7636d518e1
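
The hooks added below follow one pattern per layer: on save, each rank's shard is gathered first across the column process group and then, on the column-group leaders, across the row group, so that TENSOR local rank 0 ends up holding the full-size tensors; on load, rank 0 partitions the full tensors in the opposite order. The following is an editor-added local sketch of that mapping for the Linear2D weight, not part of the commit; q, in_features, out_features and full_weight are made-up values, and torch.chunk stands in for the distributed partition_tensor_parallel_state_dict / gather_tensor_parallel_state_dict helpers used in the diff.

import torch

# Editor-added sketch, not ColossalAI code: how a full Linear2D weight maps onto a
# q x q SUMMA mesh, mirroring the load path below (transpose, split along the last
# dim across row groups, then along dim 0 across column groups), and how the save
# path reverses it.
q = 2                                                      # mesh side, tensor_parallel_size = q * q
in_features, out_features = 8, 4
full_weight = torch.randn(out_features, in_features)       # layout of a plain torch.nn.Linear

# load: full checkpoint -> per-rank shards
w = full_weight.transpose(0, 1)                            # Linear2D stores (in_features, out_features)
row_chunks = torch.chunk(w, q, dim=-1)                     # partition across PARALLEL_2D_ROW groups
shards = [torch.chunk(c, q, dim=0) for c in row_chunks]    # then across PARALLEL_2D_COL groups
# each rank of the mesh keeps exactly one shards[i][j], of shape (in_features/q, out_features/q)

# save: per-rank shards -> full checkpoint
col_gathered = [torch.cat(cols, dim=0) for cols in shards]    # gather in column groups
row_gathered = torch.cat(col_gathered, dim=-1)                # gather in row groups
assert torch.equal(row_gathered.transpose(0, 1), full_weight) # round trip restores the dense weight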


@@ -1,4 +1,5 @@
import math
from collections import OrderedDict
from typing import Callable
import torch
@@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import *
from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
reduce_scatter_tensor_2d, split_tensor_2d)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
@@ -39,6 +42,7 @@ class Linear2D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
@@ -90,6 +94,91 @@ class Linear2D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight.transpose(0, 1)
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
destination.update(local_state)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
@@ -129,6 +218,7 @@ class LayerNorm2D(ParallelLayer):
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
super().__init__()
@@ -148,14 +238,95 @@ class LayerNorm2D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
@@ -174,10 +345,10 @@ class LayerNorm2D(ParallelLayer):
output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL)
bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
@@ -205,6 +376,7 @@ class PatchEmbedding2D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
@@ -260,6 +432,120 @@ class PatchEmbedding2D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# cls token
cls_token = state_dict.pop(cls_token_key, None)
if cls_token is not None:
local_state[cls_token_key] = cls_token
# pos embed
pos_embed = state_dict.pop(pos_embed_key, None)
if pos_embed is not None:
local_state[pos_embed_key] = pos_embed
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2d(input_)
@@ -313,6 +599,7 @@ class Embedding2D(ParallelLayer):
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
@@ -353,6 +640,57 @@ class Embedding2D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={weight_key: -1},
partition_states={weight_key: True},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={weight_key: -1},
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2d(input_)
@@ -392,6 +730,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
@@ -439,6 +778,57 @@ class VocabParallelEmbedding2D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={weight_key: -1},
partition_states={weight_key: True},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={weight_key: 0},
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
@@ -470,6 +860,7 @@ class Classifier2D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
@@ -522,6 +913,93 @@ class Classifier2D(ParallelLayer):
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, )
@@ -548,6 +1026,7 @@ class VocabParallelClassifier2D(ParallelLayer):
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
@@ -605,6 +1084,94 @@ class VocabParallelClassifier2D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
# partition in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
# gather in row groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
destination.update(local_state)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
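
One detail worth noting across these hooks: the dims/partition_states tables differ per layer. Classifier2D, for instance, splits its weight along in_features over both mesh dimensions but marks the bias with partition_states False, matching the broadcast of the bias in its constructor: every rank keeps the full bias, so saving writes a single copy and loading passes it through unchanged. Below is an editor-added sketch of that case, not part of the commit; the shapes are assumed and torch.chunk again stands in for the distributed helpers.

import torch

# Editor-added sketch, not ColossalAI code: Classifier2D's weight is partitioned along
# in_features across both mesh dimensions, while its bias stays replicated
# (partition_states=False in the hooks above).
q, num_classes, in_features = 2, 10, 8
weight = torch.randn(num_classes, in_features)
bias = torch.randn(num_classes)

# load: split the weight along dim -1 twice (row groups, then column groups); bias is untouched
shards = [torch.chunk(c, q, dim=-1) for c in torch.chunk(weight, q, dim=-1)]
local_bias = bias                                    # identical on every rank of the mesh

# save: concatenating the shards restores the full weight; the bias is written out as-is
restored = torch.cat([torch.cat(s, dim=-1) for s in shards], dim=-1)
assert torch.equal(restored, weight)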