mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 09:51:18 +00:00
[model checkpoint] updated saving/loading for 2d layers (#595)
This commit is contained in:
parent
cd13b63832
commit
7636d518e1
@ -1,4 +1,5 @@
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import tensor_parallel_env as env
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import *
|
||||
from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
|
||||
reduce_scatter_tensor_2d, split_tensor_2d)
|
||||
from ._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
|
||||
|
||||
@ -39,6 +42,7 @@ class Linear2D(ParallelLayer):
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
@ -90,6 +94,91 @@ class Linear2D(ParallelLayer):
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight.transpose(0, 1)
|
||||
# bias
|
||||
if self.bias is not None:
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/q, n/q, k/q]
|
||||
# output: [m/q, n/q, h/q]
|
||||
@ -129,6 +218,7 @@ class LayerNorm2D(ParallelLayer):
|
||||
eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
@ -148,14 +238,95 @@ class LayerNorm2D(ParallelLayer):
|
||||
# create parameters
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
with torch.no_grad():
|
||||
@ -174,10 +345,10 @@ class LayerNorm2D(ParallelLayer):
|
||||
|
||||
output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL)
|
||||
bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
@ -205,6 +376,7 @@ class PatchEmbedding2D(ParallelLayer):
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
@ -260,6 +432,120 @@ class PatchEmbedding2D(ParallelLayer):
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
cls_token_key = prefix + 'cls_token'
|
||||
pos_embed_key = prefix + 'pos_embed'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
# cls token
|
||||
cls_token = state_dict.pop(cls_token_key, None)
|
||||
if cls_token is not None:
|
||||
local_state[cls_token_key] = cls_token
|
||||
# pos embed
|
||||
pos_embed = state_dict.pop(pos_embed_key, None)
|
||||
if pos_embed is not None:
|
||||
local_state[pos_embed_key] = pos_embed
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0,
|
||||
cls_token_key: -1,
|
||||
pos_embed_key: -1
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True,
|
||||
cls_token_key: True,
|
||||
pos_embed_key: True
|
||||
},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0,
|
||||
cls_token_key: -1,
|
||||
pos_embed_key: -1
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True,
|
||||
cls_token_key: True,
|
||||
pos_embed_key: True
|
||||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
cls_token_key = prefix + 'cls_token'
|
||||
pos_embed_key = prefix + 'pos_embed'
|
||||
local_state = OrderedDict({
|
||||
weight_key: self.weight,
|
||||
bias_key: self.bias,
|
||||
cls_token_key: self.cls_token,
|
||||
pos_embed_key: self.pos_embed
|
||||
})
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0,
|
||||
cls_token_key: -1,
|
||||
pos_embed_key: -1
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True,
|
||||
cls_token_key: True,
|
||||
pos_embed_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0,
|
||||
cls_token_key: -1,
|
||||
pos_embed_key: -1
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True,
|
||||
cls_token_key: True,
|
||||
pos_embed_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_tensor_2d(input_)
|
||||
|
||||
@ -313,6 +599,7 @@ class Embedding2D(ParallelLayer):
|
||||
More details about initializer please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
@ -353,6 +640,57 @@ class Embedding2D(ParallelLayer):
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_tensor_2d(input_)
|
||||
|
||||
@ -392,6 +730,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
|
||||
More details about initializer please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
@ -439,6 +778,57 @@ class VocabParallelEmbedding2D(torch.nn.Module):
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
@ -470,6 +860,7 @@ class Classifier2D(ParallelLayer):
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
@ -522,6 +913,93 @@ class Classifier2D(ParallelLayer):
|
||||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
|
||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
if self.has_weight:
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
if self.bias is not None:
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
if self.has_weight:
|
||||
local_state[weight_key] = self.weight
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
out_shape = input_.shape[:-1] + (self.num_classes, )
|
||||
|
||||
@ -548,6 +1026,7 @@ class VocabParallelClassifier2D(ParallelLayer):
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
@ -605,6 +1084,94 @@ class VocabParallelClassifier2D(ParallelLayer):
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
if self.has_weight:
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
if self.bias is not None:
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
# partition in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
# partition in column groups
|
||||
local_state = partition_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
if self.has_weight:
|
||||
local_state[weight_key] = self.weight
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
|
||||
# gather in column groups
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
# gather in row groups
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
|
||||
local_state = gather_tensor_parallel_state_dict(
|
||||
local_state,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
local_state[weight_key] = local_state[weight_key].transpose(0, 1)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/q, n/q, k/q]
|
||||
# output: [m/q, n/q, h/q]
|
||||
|
Loading…
Reference in New Issue
Block a user