mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-05 03:26:48 +00:00
[model checkpoint] updated saving/loading for 2.5d layers (#596)
This commit is contained in:
parent
6302069c0e
commit
93089ed708
@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc
|
|||||||
from colossalai.global_variables import tensor_parallel_env as env
|
from colossalai.global_variables import tensor_parallel_env as env
|
||||||
from colossalai.nn import init as init
|
from colossalai.nn import init as init
|
||||||
from colossalai.registry import LAYERS
|
from colossalai.registry import LAYERS
|
||||||
|
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
|
||||||
|
partition_tensor_parallel_state_dict)
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from ..base_layer import ParallelLayer
|
from ..base_layer import ParallelLayer
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
from ._operation import (add_bias_2p5d, Matmul_AB_2p5D, Matmul_ABT_2p5D, all_gather_tensor_2p5d, classifier_2p5d,
|
from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
|
||||||
layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
|
layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
|
||||||
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
|
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
|
||||||
|
|
||||||
@ -40,6 +43,7 @@ class Linear2p5D(ParallelLayer):
|
|||||||
More details about ``initializer`` please refer to
|
More details about ``initializer`` please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
@ -92,6 +96,96 @@ class Linear2p5D(ParallelLayer):
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
bias_initializer(self.bias, fan_in=fan_in)
|
bias_initializer(self.bias, fan_in=fan_in)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load ``weight``/``bias`` from ``state_dict`` and hand the local shards to the parent.

    Only the process whose local rank in ``ParallelMode.TENSOR`` is 0 reads
    entries (popping them from ``state_dict``); the weight is transposed on
    axes 0 and 1 before it is passed on. The collected entries then go through
    a broadcast in the DEP group, a partition in the COL group and a partition
    in the ROW group before the parent implementation loads them.

    Args:
        state_dict: checkpoint mapping; the keys of this module are popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        full_weight = state_dict.pop(weight_key, None)
        if full_weight is not None:
            # stored layout has axes 0 and 1 swapped relative to self.weight
            shard_state[weight_key] = full_weight.transpose(0, 1)
        if self.bias is not None:
            full_bias = state_dict.pop(bias_key, None)
            if full_bias is not None:
                shard_state[bias_key] = full_bias

    # broadcast in dep groups
    # NOTE(review): the return value of broadcast_state_dict is discarded here;
    # confirm that it updates its argument in place.
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
            broadcast_state_dict(shard_state, ParallelMode.PARALLEL_2P5D_DEP)

    # partition in column groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
        col_dims = {weight_key: 0, bias_key: 0}
        col_flags = {weight_key: True, bias_key: False}
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_COL,
            dims=col_dims,
            partition_states=col_flags,
        )

    # partition in row groups
    row_dims = {weight_key: -1, bias_key: 0}
    row_flags = {weight_key: True, bias_key: True}
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_ROW,
        dims=row_dims,
        partition_states=row_flags,
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight``/``bias`` and write them into ``destination``.

    Processes whose local rank in ``ParallelMode.PARALLEL_2P5D_DEP`` is not 0
    do nothing. The others gather in the ROW group, then (on ROW rank 0) in
    the COL group. The process with local rank 0 in ``ParallelMode.TENSOR``
    transposes the weight on axes 0 and 1 and updates ``destination``.

    Args:
        destination: mapping that receives the gathered entries.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) != 0:
        return

    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    gathered = OrderedDict()
    gathered[weight_key] = self.weight
    if self.bias is not None:
        gathered[bias_key] = self.bias

    # gather in row groups
    row_dims = {weight_key: -1, bias_key: 0}
    row_flags = {weight_key: True, bias_key: True}
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_ROW,
        dims=row_dims,
        partition_states=row_flags,
        keep_vars=keep_vars,
    )

    # gather in column groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
        col_dims = {weight_key: 0, bias_key: 0}
        col_flags = {weight_key: True, bias_key: False}
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_COL,
            dims=col_dims,
            partition_states=col_flags,
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # undo the axis swap applied when loading
        gathered[weight_key] = gathered[weight_key].transpose(0, 1)
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# input: [m/dq, n/q, k/q]
|
# input: [m/dq, n/q, k/q]
|
||||||
# output: [m/dq, n/q, h/q]
|
# output: [m/dq, n/q, h/q]
|
||||||
@ -143,6 +237,7 @@ class LayerNorm2p5D(ParallelLayer):
|
|||||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
|
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -163,14 +258,95 @@ class LayerNorm2p5D(ParallelLayer):
|
|||||||
# create parameters
|
# create parameters
|
||||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||||
|
|
||||||
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||||
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||||
|
|
||||||
self._set_tensor_parallel_attribute()
|
self._set_tensor_parallel_attribute()
|
||||||
|
|
||||||
def _set_tensor_parallel_attribute(self):
|
def _set_tensor_parallel_attribute(self):
|
||||||
set_tensor_parallel_attribute_by_partition(self.gamma, self.tesseract_dim)
|
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
|
||||||
set_tensor_parallel_attribute_by_partition(self.beta, self.tesseract_dim)
|
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load ``weight``/``bias`` from ``state_dict`` and hand the local shards to the parent.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops both entries
    from ``state_dict``. The collected entries are partitioned in the ROW
    group (only on COL rank 0) and then in the COL group, both along dim 0.

    Args:
        state_dict: checkpoint mapping; the keys of this module are popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # same order as the checkpoint layout: weight first, then bias
        for key in (weight_key, bias_key):
            tensor = state_dict.pop(key, None)
            if tensor is not None:
                shard_state[key] = tensor

    split_dims = {weight_key: 0, bias_key: 0}
    split_flags = {weight_key: True, bias_key: True}

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(split_dims),
            partition_states=dict(split_flags),
        )

    # partition in column groups
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(split_dims),
        partition_states=dict(split_flags),
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight``/``bias`` and write them into ``destination``.

    Entries are gathered in the COL group, then (on COL rank 0) in the ROW
    group, both along dim 0. Only the process with local rank 0 in
    ``ParallelMode.TENSOR`` updates ``destination``.

    Args:
        destination: mapping that receives the gathered entries.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    gathered = OrderedDict([(weight_key, self.weight), (bias_key, self.bias)])

    gather_dims = {weight_key: 0, bias_key: 0}
    gather_flags = {weight_key: True, bias_key: True}

    # gather in column groups
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(gather_dims),
        partition_states=dict(gather_flags),
        keep_vars=keep_vars,
    )

    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(gather_dims),
            partition_states=dict(gather_flags),
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -188,11 +364,11 @@ class LayerNorm2p5D(ParallelLayer):
|
|||||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
||||||
|
|
||||||
output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
|
output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
|
||||||
bias = add_bias_2p5d(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||||
self.tensor_parallel_size)
|
self.tensor_parallel_size)
|
||||||
scale = add_bias_2p5d(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||||
self.tensor_parallel_size)
|
self.tensor_parallel_size)
|
||||||
@ -221,6 +397,7 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||||||
More details about ``initializer`` please refer to
|
More details about ``initializer`` please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
img_size: int,
|
img_size: int,
|
||||||
patch_size: int,
|
patch_size: int,
|
||||||
@ -276,6 +453,120 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||||||
bias_initializer(self.bias, fan_in=fan_in)
|
bias_initializer(self.bias, fan_in=fan_in)
|
||||||
position_embed_initializer(self.pos_embed)
|
position_embed_initializer(self.pos_embed)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load ``weight``, ``bias``, ``cls_token`` and ``pos_embed`` shards for this rank.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops the four
    entries from ``state_dict``. The collected entries are partitioned in the
    ROW group (only on COL rank 0) and then in the COL group; weight and bias
    use dim 0, cls_token and pos_embed use dim -1.

    Args:
        state_dict: checkpoint mapping; the keys of this module are popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    cls_token_key = prefix + 'cls_token'
    pos_embed_key = prefix + 'pos_embed'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # pop order: weight, bias, cls token, pos embed
        for key in (weight_key, bias_key, cls_token_key, pos_embed_key):
            tensor = state_dict.pop(key, None)
            if tensor is not None:
                shard_state[key] = tensor

    split_dims = {weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}
    split_flags = {weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(split_dims),
            partition_states=dict(split_flags),
        )

    # partition in column groups
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(split_dims),
        partition_states=dict(split_flags),
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight``, ``bias``, ``cls_token`` and ``pos_embed`` into ``destination``.

    Entries are gathered in the COL group, then (on COL rank 0) in the ROW
    group; weight and bias use dim 0, cls_token and pos_embed use dim -1.
    Only the process with local rank 0 in ``ParallelMode.TENSOR`` updates
    ``destination``.

    Args:
        destination: mapping that receives the gathered entries.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    cls_token_key = prefix + 'cls_token'
    pos_embed_key = prefix + 'pos_embed'

    gathered = OrderedDict([
        (weight_key, self.weight),
        (bias_key, self.bias),
        (cls_token_key, self.cls_token),
        (pos_embed_key, self.pos_embed),
    ])

    gather_dims = {weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}
    gather_flags = {weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}

    # gather in column groups
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(gather_dims),
        partition_states=dict(gather_flags),
        keep_vars=keep_vars,
    )

    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(gather_dims),
            partition_states=dict(gather_flags),
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_tensor_2p5d(input_, 0)
|
input_ = split_tensor_2p5d(input_, 0)
|
||||||
|
|
||||||
@ -329,6 +620,7 @@ class Embedding2p5D(ParallelLayer):
|
|||||||
More details about initializer please refer to
|
More details about initializer please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_embeddings: int,
|
num_embeddings: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
@ -369,6 +661,57 @@ class Embedding2p5D(ParallelLayer):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx].fill_(0)
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load the ``weight`` shard for this rank from ``state_dict``.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops the weight
    from ``state_dict``. The entry is partitioned along dim -1 in the ROW
    group (only on COL rank 0) and then along dim -1 in the COL group.

    Args:
        state_dict: checkpoint mapping; the weight key of this module is popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        full_weight = state_dict.pop(weight_key, None)
        if full_weight is not None:
            shard_state[weight_key] = full_weight

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
        )

    # partition in column groups
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1},
        partition_states={weight_key: True},
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight`` and write it into ``destination``.

    The weight is gathered along dim -1 in the COL group, then (on COL rank 0)
    along dim -1 in the ROW group. Only the process with local rank 0 in
    ``ParallelMode.TENSOR`` updates ``destination``.

    Args:
        destination: mapping that receives the gathered entry.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + 'weight'
    gathered = OrderedDict()
    gathered[weight_key] = self.weight

    # gather in column groups
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: -1},
        partition_states={weight_key: True},
        keep_vars=keep_vars,
    )

    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
input_ = split_tensor_2p5d(input_, 0)
|
input_ = split_tensor_2p5d(input_, 0)
|
||||||
|
|
||||||
@ -409,6 +752,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
|
|||||||
More details about initializer please refer to
|
More details about initializer please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_embeddings: int,
|
num_embeddings: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
@ -456,6 +800,57 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load the ``weight`` shard for this rank from ``state_dict``.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops the weight
    from ``state_dict``. The entry is partitioned along dim -1 in the ROW
    group (only on COL rank 0) and then along dim 0 in the COL group.

    Args:
        state_dict: checkpoint mapping; the weight key of this module is popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        full_weight = state_dict.pop(weight_key, None)
        if full_weight is not None:
            shard_state[weight_key] = full_weight

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
        )

    # partition in column groups
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: 0},
        partition_states={weight_key: True},
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight`` and write it into ``destination``.

    The weight is gathered along dim 0 in the COL group, then (on COL rank 0)
    along dim -1 in the ROW group. Only the process with local rank 0 in
    ``ParallelMode.TENSOR`` updates ``destination``.

    Args:
        destination: mapping that receives the gathered entry.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + 'weight'
    gathered = OrderedDict()
    gathered[weight_key] = self.weight

    # gather in column groups
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_COL,
        dims={weight_key: 0},
        partition_states={weight_key: True},
        keep_vars=keep_vars,
    )

    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims={weight_key: -1},
            partition_states={weight_key: True},
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
# Build the mask.
|
# Build the mask.
|
||||||
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
||||||
@ -491,6 +886,7 @@ class Classifier2p5D(ParallelLayer):
|
|||||||
More details about ``initializer`` please refer to
|
More details about ``initializer`` please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
num_classes: int,
|
num_classes: int,
|
||||||
@ -544,6 +940,93 @@ class Classifier2p5D(ParallelLayer):
|
|||||||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
|
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
|
||||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
|
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load ``weight``/``bias`` from ``state_dict`` and hand the local shards to the parent.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops the weight
    (only when ``self.has_weight``) and the bias (only when ``self.bias`` is
    set) from ``state_dict``. The entries are then partitioned in the ROW
    group (only on COL rank 0) and in the COL group; in both steps the weight
    uses dim -1 with its partition flag set, and the bias flag is ``False``.

    Args:
        state_dict: checkpoint mapping; the keys of this module are popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        if self.has_weight:
            full_weight = state_dict.pop(weight_key, None)
            if full_weight is not None:
                shard_state[weight_key] = full_weight
        if self.bias is not None:
            full_bias = state_dict.pop(bias_key, None)
            if full_bias is not None:
                shard_state[bias_key] = full_bias

    split_dims = {weight_key: -1, bias_key: 0}
    split_flags = {weight_key: True, bias_key: False}

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(split_dims),
            partition_states=dict(split_flags),
        )

    # partition in column groups
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(split_dims),
        partition_states=dict(split_flags),
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Gather ``weight``/``bias`` and write them into ``destination``.

    The weight is included only when ``self.has_weight`` and the bias only
    when ``self.bias`` is set. Entries are gathered in the COL group, then
    (on COL rank 0) in the ROW group; in both steps the weight uses dim -1
    with its partition flag set, and the bias flag is ``False``. Only the
    process with local rank 0 in ``ParallelMode.TENSOR`` updates
    ``destination``.

    Args:
        destination: mapping that receives the gathered entries.
        prefix: key prefix of this module inside ``destination``.
        keep_vars: passed through to ``gather_tensor_parallel_state_dict``.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'

    gathered = OrderedDict()
    if self.has_weight:
        gathered[weight_key] = self.weight
    if self.bias is not None:
        gathered[bias_key] = self.bias

    gather_dims = {weight_key: -1, bias_key: 0}
    gather_flags = {weight_key: True, bias_key: False}

    # gather in column groups
    gathered = gather_tensor_parallel_state_dict(
        gathered,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=dict(gather_dims),
        partition_states=dict(gather_flags),
        keep_vars=keep_vars,
    )

    # gather in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        gathered = gather_tensor_parallel_state_dict(
            gathered,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=dict(gather_dims),
            partition_states=dict(gather_flags),
            keep_vars=keep_vars,
        )

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        destination.update(gathered)
|
||||||
|
|
||||||
def forward(self, input_: Tensor) -> Tensor:
|
def forward(self, input_: Tensor) -> Tensor:
|
||||||
out_shape = input_.shape[:-1] + (self.num_classes, )
|
out_shape = input_.shape[:-1] + (self.num_classes, )
|
||||||
|
|
||||||
@ -571,6 +1054,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||||||
More details about ``initializer`` please refer to
|
More details about ``initializer`` please refer to
|
||||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
num_classes: int,
|
num_classes: int,
|
||||||
@ -629,6 +1113,52 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
bias_initializer(self.bias, fan_in=fan_in)
|
bias_initializer(self.bias, fan_in=fan_in)
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Load ``weight``/``bias`` from ``state_dict`` and hand the local shards to the parent.

    The process with local rank 0 in ``ParallelMode.TENSOR`` pops the weight
    (only when ``self.has_weight``) and the bias (only when ``self.bias`` is
    set) from ``state_dict``. The entries are partitioned in the ROW group
    (only on COL rank 0; weight dim -1, bias dim 0) and then in the COL group
    (weight dim 0, bias dim 0).

    Args:
        state_dict: checkpoint mapping; the keys of this module are popped from it.
        prefix: key prefix of this module inside ``state_dict``.
        *args: forwarded unchanged to the parent implementation.
        **kwargs: forwarded unchanged to the parent implementation.
    """
    weight_key = prefix + 'weight'
    bias_key = prefix + 'bias'
    shard_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        if self.has_weight:
            full_weight = state_dict.pop(weight_key, None)
            if full_weight is not None:
                shard_state[weight_key] = full_weight
        if self.bias is not None:
            full_bias = state_dict.pop(bias_key, None)
            if full_bias is not None:
                shard_state[bias_key] = full_bias

    # partition in row groups
    if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0:
        row_dims = {weight_key: -1, bias_key: 0}
        row_flags = {weight_key: True, bias_key: True}
        shard_state = partition_tensor_parallel_state_dict(
            shard_state,
            ParallelMode.PARALLEL_2P5D_ROW,
            dims=row_dims,
            partition_states=row_flags,
        )

    # partition in column groups
    col_dims = {weight_key: 0, bias_key: 0}
    col_flags = {weight_key: True, bias_key: True}
    shard_state = partition_tensor_parallel_state_dict(
        shard_state,
        ParallelMode.PARALLEL_2P5D_COL,
        dims=col_dims,
        partition_states=col_flags,
    )

    super()._load_from_state_dict(shard_state, prefix, *args, **kwargs)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
# input: [m/dq, n/q, k/q]
|
# input: [m/dq, n/q, k/q]
|
||||||
# output: [m/dq, n/q, h/q]
|
# output: [m/dq, n/q, h/q]
|
||||||
|
Loading…
Reference in New Issue
Block a user