From 7636d518e1033b741e7346d21fb551b55a4384ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Fri, 1 Apr 2022 16:50:34 +0800 Subject: [PATCH] [model checkpoint] updated saving/loading for 2d layers (#595) --- colossalai/nn/layer/parallel_2d/layers.py | 583 +++++++++++++++++++++- 1 file changed, 575 insertions(+), 8 deletions(-) diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index 7226748d7..e2cc52801 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -1,4 +1,5 @@ import math +from collections import OrderedDict from typing import Callable import torch @@ -10,13 +11,15 @@ from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict from colossalai.utils.cuda import get_current_device from torch import Tensor from torch.nn import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import * +from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, + reduce_scatter_tensor_2d, split_tensor_2d) from ._utils import assert_summa_initialization, get_summa_dim_from_env @@ -39,6 +42,7 @@ class Linear2D(ParallelLayer): More details about ``initializer`` please refer to `init `_. 
""" + def __init__(self, in_features: int, out_features: int, @@ -90,6 +94,91 @@ class Linear2D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight.transpose(0, 1) + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, 
+ bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + local_state[weight_key] = local_state[weight_key].transpose(0, 1) + destination.update(local_state) + def forward(self, x: Tensor) -> Tensor: # input: [m/q, n/q, k/q] # output: [m/q, n/q, h/q] @@ -129,6 +218,7 @@ class LayerNorm2D(ParallelLayer): eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. """ + def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None): super().__init__() @@ -148,14 +238,95 @@ class LayerNorm2D(ParallelLayer): # create parameters factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) - self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) + self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) + self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs)) self._set_tensor_parallel_attributes() def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2) - set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2) + set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # 
partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): @@ -174,10 +345,10 @@ class LayerNorm2D(ParallelLayer): output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL) - bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank, + bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, 
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) - scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank, + scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) output = torch.addcmul(bias, scale, output) @@ -205,6 +376,7 @@ class PatchEmbedding2D(ParallelLayer): More details about ``initializer`` please refer to `init `_. """ + def __init__(self, img_size: int, patch_size: int, @@ -260,6 +432,120 @@ class PatchEmbedding2D(ParallelLayer): bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + # cls token + cls_token = state_dict.pop(cls_token_key, None) + if cls_token is not None: + local_state[cls_token_key] = cls_token + # pos embed + pos_embed = state_dict.pop(pos_embed_key, None) + if pos_embed is not None: + local_state[pos_embed_key] = pos_embed + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + 
pos_embed_key: True + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + cls_token_key = prefix + 'cls_token' + pos_embed_key = prefix + 'pos_embed' + local_state = OrderedDict({ + weight_key: self.weight, + bias_key: self.bias, + cls_token_key: self.cls_token, + pos_embed_key: self.pos_embed + }) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: 0, + bias_key: 0, + cls_token_key: -1, + pos_embed_key: -1 + }, + partition_states={ + weight_key: True, + bias_key: True, + cls_token_key: True, + pos_embed_key: True + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + def forward(self, input_: Tensor) -> Tensor: input_ = split_tensor_2d(input_) @@ -313,6 +599,7 @@ class Embedding2D(ParallelLayer): More details about initializer please refer to `init `_ """ + def __init__(self, num_embeddings: int, embedding_dim: int, @@ -353,6 +640,57 @@ class Embedding2D(ParallelLayer): with torch.no_grad(): 
self.weight[self.padding_idx].fill_(0) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + def forward(self, input_: Tensor) -> Tensor: input_ = split_tensor_2d(input_) @@ -392,6 +730,7 @@ class VocabParallelEmbedding2D(torch.nn.Module): More details about initializer please refer to `init `_. 
""" + def __init__(self, num_embeddings: int, embedding_dim: int, @@ -435,10 +774,61 @@ class VocabParallelEmbedding2D(torch.nn.Module): def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) + if 
gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + def forward(self, input_: Tensor) -> Tensor: input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) masked_input = input_.clone() - self.vocab_start_index @@ -470,6 +860,7 @@ class Classifier2D(ParallelLayer): More details about ``initializer`` please refer to `init `_. """ + def __init__(self, in_features: int, num_classes: int, @@ -522,6 +913,93 @@ class Classifier2D(ParallelLayer): broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + # partition in column groups + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + ) + + super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + 
local_state[bias_key] = self.bias + + # gather in column groups + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_COL, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + # gather in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars, + ) + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + destination.update(local_state) + def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes, ) @@ -548,6 +1026,7 @@ class VocabParallelClassifier2D(ParallelLayer): More details about ``initializer`` please refer to `init `_. """ + def __init__(self, in_features: int, num_classes: int, @@ -605,6 +1084,94 @@ class VocabParallelClassifier2D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + # partition in row groups + if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0: + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_2D_ROW, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + ) + # partition in column groups + local_state = 
partition_tensor_parallel_state_dict(
+            local_state,
+            ParallelMode.PARALLEL_2D_COL,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: True
+            },
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in column groups (weight and bias along dim 0), mirroring the column partition done on load
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            ParallelMode.PARALLEL_2D_COL,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: True
+            },
+            keep_vars=keep_vars,
+        )
+        # gather in row groups, only on ranks with local rank 0 in their column group (weight: dim -1, bias: dim 0)
+        if gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                ParallelMode.PARALLEL_2D_ROW,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # publish on tensor rank 0 only; no transpose, matching the untransposed read in _load_from_state_dict
+            destination.update(local_state)
+
     def forward(self, x: Tensor) -> Tensor:
         # input: [m/q, n/q, k/q]
         # output: [m/q, n/q, h/q]