diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index e8fb8afab..2be3b45d0 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -1,4 +1,5 @@
 import math
+from collections import OrderedDict
 from typing import Callable
 
 import torch
@@ -12,13 +13,15 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
+from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
+                                            partition_tensor_parallel_state_dict)
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn import Parameter
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
-from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
+from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
+                         linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
 from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@@ -61,6 +64,67 @@ class LayerNorm3D(ParallelLayer):
         init.zeros_()(self.bias)
         init.ones_()(self.weight)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
                             self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
@@ -135,6 +199,122 @@ class Linear3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight.transpose(0, 1)
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: -1,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: -1,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            local_state[weight_key] = local_state[weight_key].transpose(0, 1)
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                          self.output_parallel_mode)
@@ -212,6 +392,73 @@ class Classifier3D(ParallelLayer):
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
             broadcast(self.bias, input_src_rank, self.input_parallel_mode)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict()
+        if self.has_weight:
+            local_state[weight_key] = self.weight
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                              self.output_parallel_mode)
@@ -296,6 +543,122 @@ class VocabParallelClassifier3D(ParallelLayer):
             broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
             broadcast(self.bias, output_src_rank, self.output_parallel_mode)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            if self.has_weight:
+                weight = state_dict.pop(weight_key, None)
+                if weight is not None:
+                    local_state[weight_key] = weight
+            # bias
+            if self.bias is not None:
+                bias = state_dict.pop(bias_key, None)
+                if bias is not None:
+                    local_state[bias_key] = bias
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={
+                weight_key: 0,
+                bias_key: 0
+            },
+            partition_states={
+                weight_key: True,
+                bias_key: False
+            },
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: -1,
+                    bias_key: 0
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: False
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
                          self.weight_parallel_mode, self.output_parallel_mode)
@@ -392,12 +755,98 @@ class PatchEmbedding3D(ParallelLayer):
         self.cls_token.register_hook(self._sync_grad_hook)
         self.pos_embed.register_hook(self._sync_grad_hook)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        cls_token_key = prefix + 'cls_token'
+        pos_embed_key = prefix + 'pos_embed'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+            # bias
+            bias = state_dict.pop(bias_key, None)
+            if bias is not None:
+                local_state[bias_key] = bias
+            # cls token
+            cls_token = state_dict.pop(cls_token_key, None)
+            if cls_token is not None:
+                local_state[cls_token_key] = cls_token
+            # pos embed
+            pos_embed = state_dict.pop(pos_embed_key, None)
+            if pos_embed is not None:
+                local_state[pos_embed_key] = pos_embed
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0,
+                    cls_token_key: -1,
+                    pos_embed_key: -1
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                    cls_token_key: True,
+                    pos_embed_key: True
+                },
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        bias_key = prefix + 'bias'
+        cls_token_key = prefix + 'cls_token'
+        pos_embed_key = prefix + 'pos_embed'
+        local_state = OrderedDict({
+            weight_key: self.weight,
+            bias_key: self.bias,
+            cls_token_key: self.cls_token,
+            pos_embed_key: self.pos_embed
+        })
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={
+                    weight_key: 0,
+                    bias_key: 0,
+                    cls_token_key: -1,
+                    pos_embed_key: -1
+                },
+                partition_states={
+                    weight_key: True,
+                    bias_key: True,
+                    cls_token_key: True,
+                    pos_embed_key: True
+                },
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
-            output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
+            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
@@ -480,6 +929,49 @@ class Embedding3D(ParallelLayer):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+            )
+        # broadcast in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = broadcast_state_dict(local_state, self.input_parallel_mode)
+        # broadcast in weight groups
+        local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
@@ -570,6 +1062,76 @@ class VocabParallelEmbedding3D(torch.nn.Module):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        local_state = OrderedDict()
+        weight_key = prefix + 'weight'
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            # weight
+            weight = state_dict.pop(weight_key, None)
+            if weight is not None:
+                local_state[weight_key] = weight
+
+        # partition in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+            )
+        # partition in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = partition_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+            )
+        # partition in weight groups
+        local_state = partition_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={weight_key: 0},
+            partition_states={weight_key: True},
+        )
+
+        super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        weight_key = prefix + 'weight'
+        local_state = OrderedDict({weight_key: self.weight})
+
+        # gather in weight groups
+        local_state = gather_tensor_parallel_state_dict(
+            local_state,
+            self.weight_parallel_mode,
+            dims={weight_key: 0},
+            partition_states={weight_key: True},
+            keep_vars=keep_vars,
+        )
+        # gather in input groups
+        if gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.input_parallel_mode,
+                dims={weight_key: 0},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        # gather in output groups
+        if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
+                gpc.get_local_rank(self.weight_parallel_mode) == 0:
+            local_state = gather_tensor_parallel_state_dict(
+                local_state,
+                self.output_parallel_mode,
+                dims={weight_key: -1},
+                partition_states={weight_key: True},
+                keep_vars=keep_vars,
+            )
+        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            destination.update(local_state)
+
     def forward(self, input_: Tensor) -> Tensor:
         input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)