diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 6c5033773..442ac4a8d 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A mapping from integer param_id to param32 shape.
-
     if optim is None:
         return {}
     param_info = {"id2shape": {}}
+
     start_index = 0
     for group in optim.param_groups:
         for param_id, param in enumerate(group["params"], start_index):
@@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-        optimizer_params_info = get_param_info(optimizer)
+        params_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
             # FIXME(ver217): gemini does not support sync bn
@@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase):
                 **self.zero_optim_config,
                 **self.optim_kwargs,
                 tp_group=self.tp_group,
-                optimizer_params_info=optimizer_params_info,
+                params_info=params_info,
                 verbose=self.verbose,
             )
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 29cec7cfd..8d12eb806 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):
     if optim is None:
         return {}
 
-    param_info = {
-        "param_groups": [],
-        "param2id": {},
-        "id2param": {},
-        "param2shape": {},
-    }
+    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+        make_vocab_size_divisible_by (int, optional): The value to make the vocabulary size divisible by when padding it, so that a faster kernel can be chosen. Defaults to 64.
+ """ def __init__( @@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_model_chunks: int = 1, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, ) -> None: super().__init__() assert ( @@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase): sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 808227249..7946d9b9c 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO @@ -32,6 +38,7 @@ from .utils import ( save_param_groups, save_state_dict, save_state_dict_shards, + search_padding_dim, search_tp_partition_dim, sharded_optimizer_loading_epilogue, ) @@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if param is None: continue # Gather tensor pieces when using tensor parallel. + if is_padded_tensor(param): + param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: @@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # When pipeline is used, each stage produces its own shard files and index files. # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. - final_index_file_path = copy.deepcopy(save_index_file) tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) @@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): use_safetensors=use_safetensors, use_pp_format=True, ) + if control_saving: assert ( self.dp_rank == 0 and self.tp_rank == 0 @@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) + padding_dim = search_padding_dim(v.shape, original_shape) + if padding_dim is not None: + v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) + v = to_unpadded_tensor(v) + state_[k] = v.detach().clone().to(device) return state_ @@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. 
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + global_shape = current_shape + if partition_dim is not None: + # pad embedding params + global_shape = ( + *current_shape[:partition_dim], + current_shape[partition_dim] * self.tp_size, + *current_shape[partition_dim + 1 :], + ) + + padding_dim = search_padding_dim(global_shape, original_shape) + if padding_dim is not None: + v = to_padded_tensor(v, global_shape[padding_dim], padding_dim) + if partition_dim is not None: slice_size = current_shape[partition_dim] v = v.split(slice_size, dim=partition_dim)[self.tp_rank] diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 2a1d4de9b..6197be9d1 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz return partition_dim +def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]: + padding_dim = None + for dim, length in enumerate(global_shape): + if length > original_shape[dim]: + padding_dim = dim + break + return padding_dim + + # ====================================== # Helper classes and functions for saving shard file # ====================================== diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7b8aa5380..f17fad1b6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,8 +1,8 @@ from ._operation import all_to_all_comm from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput -from .embedding import Embedding1D, VocabParallelEmbedding1D -from .linear import Linear1D_Col, Linear1D_Row +from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule @@ -25,6 +25,9 @@ __all__ = [ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "PaddingEmbedding", + "PaddingLMHead", + "VocabParallelLMHead1D", "AttnMaskType", "ColoAttention", "all_to_all_comm", diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index d081b2040..cb7eceae4 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import gather_forward_split_backward, reduce_forward -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset -__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] +__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"] class Embedding1D(ParallelModule): @@ -161,7 +161,80 @@ class Embedding1D(ParallelModule): return output_parallel -class VocabParallelEmbedding1D(ParallelModule): +class PaddingEmbedding(PaddingParallelModule): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[nn.Parameter] = None, + make_vocab_size_divisible_by: int = 64, + *args, + **kwargs, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = 
embedding_dim + self.embed_args = args + self.embed_kwargs = kwargs + self.padding_idx = padding_idx + if num_embeddings % make_vocab_size_divisible_by != 0: + self.num_embeddings = ( + num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by) + ) + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + super().__init__(self.num_embeddings, num_embeddings, weight) + + if weight is None: + self.reset_parameters() + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + @staticmethod + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + LazyInitContext.materialize(module) + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + # create the parallel module + padding_embedding = PaddingEmbedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + weight=module.weight, + *args, + **kwargs, + ) + + return padding_embedding + + +class VocabParallelEmbedding1D(PaddingParallelModule): r"""Embedding parallelized in the vocabulary dimension. 
Args: @@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule): process_group: ProcessGroup = None, weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), + make_vocab_size_divisible_by: int = 64, *args, **kwargs, ): - super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.embed_args = args @@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule): tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition + # generate weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + # calculate new padding size + multiple = make_vocab_size_divisible_by * tensor_parallel_size + if num_embeddings % multiple != 0: + self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple) + + # resize vocabulary size + super().__init__(self.num_embeddings, num_embeddings, weight) + + # deal with tensor parallelism + self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition @@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule): seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - # parameter - if weight is None: - factory_kwargs = {"device": device, "dtype": dtype} - self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) - else: - weight.data = weight.data.to(device=device, dtype=dtype) - self.weight = weight if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule): @staticmethod def from_native_module( module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: + ) -> PaddingParallelModule: r""" Convert a native pytorch embedding module to a parallel module. """ @@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule): # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding( masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs ) - # Mask the output embedding. 
embedding_output = output_parallel.clone() embedding_output[input_mask, :] = 0.0 diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 7c8619ad8..37c754241 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -32,7 +32,7 @@ from ._operation import ( reducescatter_forward_gather_backward, split_forward_gather_backward, ) -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset __all__ = ["Linear1D_Col", "Linear1D_Row"] @@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule): bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, ): - super().__init__() + super().__init__(weight=weight, bias_=bias_, **kwargs) # Keep input parameters self.in_features = in_features @@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule): else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight + if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, self.process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule): @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. @@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, - *args, **kwargs, ) @@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule): @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. 
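The vocabulary-padding arithmetic that the changes above rely on (make_vocab_size_divisible_by in the plugins, search_padding_dim in checkpoint_io/utils.py, and the resized VocabParallelEmbedding1D) can be summarized with a small standalone sketch. This is illustrative only and not part of the patch: the helper names mirror the patch, but these pure-Python versions and the example numbers are assumptions for demonstration.

from typing import Optional, Sequence


def pad_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    # Round the vocabulary up to a multiple of make_vocab_size_divisible_by * tp_size,
    # so every tensor-parallel rank receives an equal, kernel-friendly partition.
    multiple = make_vocab_size_divisible_by * tp_size
    if vocab_size % multiple != 0:
        vocab_size = vocab_size + multiple - (vocab_size % multiple)
    return vocab_size


def find_padding_dim(global_shape: Sequence[int], original_shape: Sequence[int]) -> Optional[int]:
    # Same idea as search_padding_dim above: the first dimension whose gathered length
    # exceeds the original length is the padded (vocabulary) dimension.
    for dim, (cur, orig) in enumerate(zip(global_shape, original_shape)):
        if cur > orig:
            return dim
    return None


# A 32000-token vocabulary with tp_size=4 is already a multiple of 64 * 4 = 256 ...
assert pad_vocab_size(32000, 64, tp_size=4) == 32000
# ... while 32001 is rounded up to the next multiple of 256,
assert pad_vocab_size(32001, 64, tp_size=4) == 32256
# and the padding is detected on dim 0 when comparing the gathered shape with the original one.
assert find_padding_dim((32256, 4096), (32001, 4096)) == 0

This rounding is what VocabParallelEmbedding1D and VocabParallelLMHead1D apply before sharding, which is also why the checkpoint IO above strips the padding again with to_unpadded_tensor, and why the padded LM heads slice their output back to old_num_embeddings, before anything is saved or returned to the user.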
@@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule): process_group=process_group, weight=module.weight, bias_=module.bias, - *args, **kwargs, ) @@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule): return output else: return output, self.bias + + +class PaddingLMHead(PaddingParallelModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + + if out_features % make_vocab_size_divisible_by != 0: + self.out_features = ( + out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by) + ) + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + if bias: + if bias_ is None: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + else: + bias_ = None + + # resize embeddings + super().__init__(self.out_features, out_features, weight, bias_) + + if weight is None: + self.reset_parameters(weight_initializer, bias_initializer) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + # ensure only one process group is passed + + lm_head_linear = PaddingLMHead( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input: Tensor) -> Tensor: + output = F.linear(input, self.weight, self.bias) + output = output[..., : self.old_num_embeddings] + return output + + +class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. 
+ gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + make_vocab_size_divisible_by: int = 64, + **kwargs, + ): + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs)) + if bias: + if bias_ is None: + bias_ = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + bias_ = None + + # calculate new vocab size + self.tensor_parallel_size = dist.get_world_size(group=process_group) + new_out_features = out_features + multiple = make_vocab_size_divisible_by * self.tensor_parallel_size + if out_features % multiple != 0: + new_out_features = out_features + multiple - (out_features % multiple) + + super().__init__( + in_features=in_features, + out_features=new_out_features, + bias=bias, + device=device, + process_group=process_group, + weight=weight, + bias_=bias_, + **kwargs, + new_num_embeddings=new_out_features, + old_num_embeddings=out_features, + ) + # get the length of valid embeddings + tp_rank = dist.get_rank(process_group) + partition_size = self.new_num_embeddings // dist.get_world_size(process_group) + if self.old_num_embeddings >= (tp_rank + 1) * partition_size: + self.num_valid_embeddings_local = partition_size + elif self.old_num_embeddings >= tp_rank * partition_size: + self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size + else: + self.num_valid_embeddings_local = 0 + + @staticmethod + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. 
+ """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + lm_head_linear = VocabParallelLMHead1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # get forward output + if self.skip_bias_add: + output, bias = super().forward(input_) + else: + output = super().forward(input_) + + # delete the padding of output + if self.gather_output: + output = output[..., : self.old_num_embeddings] + else: + output = output[..., : self.num_valid_embeddings_local] + + # return + if self.skip_bias_add: + return output, bias + return output diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index c4cf3fb85..6d99efc19 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -15,7 +15,14 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + process_group: ProcessGroup, + vocab_size: int, + ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -41,15 +48,21 @@ class DistCrossEntropy(Function): vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # mask the target in the local device - partition_vocab_size = vocab_logits.size()[-1] rank = dist.get_rank(group=process_group) world_size = dist.get_world_size(group=process_group) - global_vocab_size = partition_vocab_size * world_size + if vocab_size == None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size # [down, up) => false, other device and -100 => true delta = (global_vocab_size + world_size - 1) // world_size down_threshold = rank * delta up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size mask = (target < down_threshold) | (target >= up_threshold) masked_target = target.clone() - down_threshold masked_target[mask] = 0 @@ -57,7 +70,8 @@ class DistCrossEntropy(Function): # reshape the logits and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] - logits_2d = vocab_logits.view(-1, partition_vocab_size) + self_vocab_size = vocab_logits.size()[-1] + logits_2d = vocab_logits.view(-1, self_vocab_size) masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero @@ -104,10 +118,14 @@ class DistCrossEntropy(Function): grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None + return grad_logits, None, None, None, None def cross_entropy_1d( - vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None + vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None, + vocab_size: int = None, 
) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 6c0d83cc7..11ef73538 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -3,7 +3,7 @@ import itertools from abc import ABC, abstractmethod -from typing import List, Union +from typing import List, Optional, Union import torch import torch.nn as nn @@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import ( is_distributed_tensor, sharded_tensor_to_param, ) +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor __all__ = ["ParallelModule"] class ParallelModule(nn.Module, ABC): + def __init__(self, **kwargs): + super().__init__() + @abstractmethod def from_native_module( module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None @@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC): """ for name, param in self._parameters.items(): if param is not None: - destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: @@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC): input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) + + +class PaddingParallelModule(ParallelModule): + def __init__( + self, + new_num_embeddings: int, + old_num_embeddings: int, + weight: Optional[nn.Parameter], + bias_: Optional[nn.Parameter] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.new_num_embeddings = new_num_embeddings + self.old_num_embeddings = old_num_embeddings + self.weight = weight + self.bias = bias_ + + if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings): + self.resize_embedding_weight() + + if self.bias is not None and not ( + is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings + ): + self.resize_embedding_bias() + + @abstractmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "PaddingParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + raise NotImplementedError + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. 
+ + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param = gather_distributed_param(param, keep_vars=keep_vars) + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + destination[prefix + name] = param.data + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. 
+ See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append( + 'While copying the parameter named "{}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + "received {}".format(key, type(input_param)) + ) + continue + + if is_padded_tensor(param): + input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim) + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. 
+ is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def resize_embedding_weight(self): + self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0) + + def resize_embedding_bias(self): + self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1306c8aa6..26088569a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import gather_forward_split_backward logger = logging.get_logger(__name__) @@ -397,13 +396,11 @@ class GPT2PipelineForwards: shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: loss = loss_fct(shift_logits, shift_labels) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else 
output diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0f1b4ad0a..c3b5426c2 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -316,7 +316,10 @@ class LlamaPipelineForwards: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) if not return_dict: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index d67ab0a3c..e976672bb 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -195,3 +195,12 @@ class Policy(ABC): List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] + + def tie_weight_check(self): + input_embedding = self.model.get_input_embeddings() + output_embedding = self.model.get_output_embeddings() + return ( + input_embedding is not None + and output_embedding is not None + and id(input_embedding.weight) == id(output_embedding.weight) + ) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0a61d8cff..d43fc893a 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -37,17 +37,7 @@ class BertPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -62,6 +52,13 @@ class BertPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -150,10 +147,6 @@ class BertPolicy(Policy): policy[BertEmbeddings] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForReplicatedInput, @@ -168,6 +161,18 @@ class BertPolicy(Policy): 
target_key=BertModel, ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=embedding_cls, + ) + ], + policy=policy, + target_key=BertEmbeddings, + ) + # optimization configuration # Handle bert layer self.append_or_create_submodule_replacement( @@ -237,8 +242,21 @@ class BertPolicy(Policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=BertLMPredictionHead, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 9be2a1e78..b845e9336 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -17,16 +17,7 @@ class BlipPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - vocab_size = self.model.config.qformer_config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -43,6 +34,13 @@ class BlipPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -202,22 +200,48 @@ class BlipPolicy(Policy): ], ) - policy[OPTForCausalLM] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="model.decoder.embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, - ), - ] - ) - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) + else: + 
self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) # optimization configuration # Handle Blip2EncoderLayer layer self.append_or_create_submodule_replacement( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 2becadc3f..953592abc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -35,16 +35,7 @@ class BloomPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -52,6 +43,13 @@ class BloomPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -112,12 +110,19 @@ class BloomPolicy(Policy): method_replacement={ "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=BloomModel, ) # optimization configuration @@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=BloomForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=BloomForCausalLM, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index dabc14bff..f205835e7 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy): pass def preprocess(self): - # Resize embedding - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.padded_vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - 
new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - if self.pipeline_stage_manager is not None: # the batch_size_dim is bounded to Model bsz_dim = 1 setattr(self.model, "batch_size_dim", bsz_dim) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: if self.model.config.rmsnorm: norm_cls = col_nn.FusedRMSNorm @@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription( - attribute_replacement={}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embedding.word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ], - ) - policy[GLMBlock] = ModulePolicyDescription( attribute_replacement={ "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads @@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy): ), ], ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=ChatGLMModel, + ) # optimization configuration self.append_or_create_submodule_replacement( description=[ diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index fe61c406f..a2f110a41 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -32,16 +32,7 @@ class FalconPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -58,6 +49,14 @@ class FalconPolicy(Policy): warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: attn_attribute_replacement = { "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -98,12 +97,19 @@ class FalconPolicy(Policy): method_replacement={ "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - 
target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=FalconModel, ) # optimization configuration @@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), ), policy=policy, target_key=FalconForCausalLM, ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), + ), + policy=policy, + target_key=FalconForCausalLM, + ) + if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=FalconForCausalLM, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 380a432dc..98db7b948 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -34,12 +34,7 @@ class GPT2Policy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -47,6 +42,13 @@ class GPT2Policy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -73,10 +75,6 @@ class GPT2Policy(Policy): if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -137,6 +135,17 @@ class GPT2Policy(Policy): ), ], ) + if embedding_cls is not None: + # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=GPT2Model, + ) # optimization configuration self.append_or_create_submodule_replacement( @@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": False, + "make_vocab_size_divisible_by": 
self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy): addon_module[GPT2LMHeadModel].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( @@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index eab4c214a..4b69137a6 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -29,22 +29,21 @@ class GPTJPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -54,10 +53,6 @@ class GPTJPolicy(Policy): if self.shard_config.enable_tensor_parallelism: policy[GPTJModel] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -126,6 +121,17 @@ class GPTJPolicy(Policy): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, 
+ target_key=GPTJModel, + ) + # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement( @@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } - policy.update(addon_module) + else: + addon_module = { + GPTJForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index bb4551b2c..ff686a179 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -6,7 +6,16 @@ import torch.nn as nn from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + RMSNorm, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.llama import ( LlamaPipelineForwards, @@ -26,15 +35,7 @@ class LlamaPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -42,6 +43,13 @@ class LlamaPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -167,10 +175,12 @@ class LlamaPolicy(Policy): ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=LlamaModel, @@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy): new_item[LlamaForCausalLM].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - 
policy.update(new_item) + else: + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) + } + policy.update(new_item) if self.pipeline_stage_manager: # set None as default diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c0b8b3375..b225fd2a9 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -3,7 +3,15 @@ from typing import Dict, Union import torch.nn as nn -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.mistral import get_mistral_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -16,15 +24,7 @@ class MistralPolicy(Policy): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -32,6 +32,13 @@ class MistralPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -80,10 +87,12 @@ class MistralPolicy(Policy): ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=MistralModel, @@ -146,6 +155,8 @@ class MistralForCausalLMPolicy(MistralPolicy): from transformers import MistralForCausalLM policy = super().module_policy() + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -153,16 +164,30 @@ class MistralForCausalLMPolicy(MistralPolicy): MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + ), + ) + ] + ) + } + else: + new_item = { + MistralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ) ] ) } - if self.pipeline_stage_manager: 
- warnings.warn("Mistral doesn't support pipeline parallelism now.") - - policy.update(new_item) + policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 98e584be8..ac78ff6a7 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -5,7 +5,16 @@ from typing import Callable, Dict, List import torch.nn as nn from torch import Tensor, nn -from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + LayerNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func @@ -41,16 +50,7 @@ class OPTPolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -58,6 +58,13 @@ class OPTPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedLayerNorm else: @@ -68,14 +75,6 @@ class OPTPolicy(Policy): warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ] - ) policy[OPTDecoderLayer] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( @@ -114,6 +113,17 @@ class OPTPolicy(Policy): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=OPTDecoder, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -253,8 +263,20 @@ class OPTForCausalLMPolicy(OPTPolicy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=OPTForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=OPTForCausalLM, diff --git 
a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 0c8ec15fa..3c7e92b47 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -13,8 +13,11 @@ from colossalai.shardformer.layer import ( FusedRMSNorm, Linear1D_Col, Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription @@ -36,16 +39,7 @@ class T5BasePolicy(Policy): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -61,6 +55,13 @@ class T5BasePolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -77,10 +78,6 @@ class T5BasePolicy(Policy): suffix="dropout", target_module=DropoutForParallelInput, ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), ] ) policy[T5LayerSelfAttention] = ModulePolicyDescription( @@ -176,6 +173,17 @@ class T5BasePolicy(Policy): ] ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5Stack, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -370,11 +378,19 @@ class T5ModelPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5Model, @@ -406,17 +422,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( - description=[ - 
SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ), - ], + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), policy=policy, target_key=T5ForConditionalGeneration, ) @@ -467,11 +510,19 @@ class T5EncoderPolicy(T5BasePolicy): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5EncoderModel, diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index c63f6d1cc..0b5114fa6 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -45,11 +45,7 @@ class WhisperPolicy(Policy): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -63,6 +59,13 @@ class WhisperPolicy(Policy): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -167,13 +170,17 @@ class WhisperPolicy(Policy): ], ) - policy[WhisperDecoder] = ModulePolicyDescription( - sub_module_replacement=[ + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), - ] + ], + policy=policy, + target_key=WhisperDecoder, ) # optimization configuration @@ -280,8 +287,21 @@ class WhisperPolicy(Policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="proj_out", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) + else: + 
self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=WhisperForConditionalGeneration, @@ -526,9 +546,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): - def preprocess(self): - return self.model - def module_policy(self): from transformers import WhisperForAudioClassification diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 7489873c2..963732543 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -42,10 +42,9 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True + make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) - # TODO padding vocab - # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 667a7b78e..c2cf73181 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.misc import LayoutException +from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from .sharding_spec import ShardingSpec @@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta): [3.], [3.]]) """ + _, comm_action_sequence = self.layout_converting(source_layout, target_layout) + + target_tensor = tensor for comm_spec in comm_action_sequence: - tensor = comm_spec.covert_spec_to_action(tensor) - tensor.dist_layout = target_layout - return tensor + target_tensor = comm_spec.covert_spec_to_action(target_tensor) + target_tensor.dist_layout = target_layout + + # restore the padding information + if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor): + target_tensor = init_as_padded_tensor( + target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim + ) + + return target_tensor diff --git a/colossalai/tensor/padded_tensor/__init__.py b/colossalai/tensor/padded_tensor/__init__.py new file mode 100644 index 000000000..353ff35f8 --- /dev/null +++ b/colossalai/tensor/padded_tensor/__init__.py @@ -0,0 +1,3 @@ +from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor + +__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"] diff --git a/colossalai/tensor/padded_tensor/api.py b/colossalai/tensor/padded_tensor/api.py new file mode 100644 index 000000000..5b66c016b --- /dev/null +++ b/colossalai/tensor/padded_tensor/api.py @@ -0,0 +1,128 @@ +import torch + + +def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the 
tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor._unpad_detach = ptensor.detach + ptensor._unpad_clone = ptensor.clone + + def new_detach(self): + t_ = self._unpad_detach() + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._unpad_clone(*args, **kwargs) + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + # bind the new methods to the tensor + ptensor.detach = new_detach.__get__(ptensor) + ptensor.clone = new_clone.__get__(ptensor) + return ptensor + + +def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor.detach = ptensor._unpad_detach + ptensor.clone = ptensor._unpad_clone + + delattr(ptensor, "_unpad_detach") + delattr(ptensor, "_unpad_clone") + + return ptensor + + +def is_padded_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a padding tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a padding tensor. + """ + return hasattr(tensor, "_padding_dim") + + +def to_padded_tensor( + tensor: torch.Tensor, + current_length: int, + padding_dim: int, +) -> torch.Tensor: + assert ( + padding_dim < tensor.dim() + ), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}" + + if is_padded_tensor(tensor): + return tensor + + origin_length = tensor.shape[padding_dim] + padding_num = current_length - origin_length + padding_data = torch.zeros( + *tensor.shape[:padding_dim], + padding_num, + *tensor.shape[padding_dim + 1 :], + device=tensor.device, + dtype=tensor.dtype, + ) + tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous() + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor + + +def to_unpadded_tensor(ptensor: torch.Tensor): + if not is_padded_tensor(ptensor): + return ptensor + + unpad_slices = [slice(None)] * ptensor.dim() + unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length) + ptensor.data = ptensor.data[tuple(unpad_slices)] + + delattr(ptensor, "_padding_dim") + delattr(ptensor, "_origin_length") + delattr(ptensor, "_current_length") + + _hijack_back_detach_and_clone(ptensor) + + return ptensor + + +def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int): + if is_padded_tensor(tensor): + return tensor + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e415b5fc3..bdf7b19f3 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1 rtol=rtol, atol=atol, msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs 
{b.dtype}", + dtype: {a.dtype} vs {b.dtype}", ) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index bc6c9d088..c79422171 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, is_ddp_ignored @@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper): record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() + if is_padded_tensor(tensor): + record_tensor = init_as_padded_tensor( + record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim + ) + record_tensor = to_unpadded_tensor(record_tensor) assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper): # deal with ddp ignored parameters destination[prefix + name] = param if keep_vars else param.detach() else: + if is_padded_tensor(p_mapping[param]): + p_mapping[param] = to_unpadded_tensor(p_mapping[param]) destination[prefix + name] = p_mapping[param] del p_mapping del param_to_save_data @@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper): list, and will be reported together in :meth:`~torch.nn.Module.load_state_dict` """ + for hook in self._load_state_dict_pre_hooks.values(): hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper): if state_key in state_dict: input_param = state_dict[state_key] + global_shape = dest_tensor.shape + if source_device_mesh is not None and source_sharding_spec is not None: + global_shape = get_global_shape(dest_tensor) + + if is_padded_tensor(dest_tensor): + padding_dim = dest_tensor._padding_dim + input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim) + if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) elif shard_fn is not None and gather_fn is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 18367af59..ae02fe297 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, get_device_mesh, + get_global_shape, get_sharding_spec, init_as_dtensor, init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager @@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper): max_norm: float = 0.0, norm_type: float = 2.0, tp_group: ProcessGroup = None, - optimizer_params_info=None, + params_info=None, verbose: bool = False, **defaults: Any, ): @@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper): self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm self.tp_group = 
tp_group - self.optimizer_params_info = optimizer_params_info + self.params_info = params_info self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose @@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper): is_customized_distributed = is_customized_distributed_tensor(param) shard_spec = get_sharding_spec(param) if is_dtensor else None device_mesh = get_device_mesh(param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] # If the chunk is kept gathered, # the parameters are treated the same as that of those in strict DDP during training. @@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper): else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() if is_dtensor: + global_shape = get_global_shape(param) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = init_as_dtensor( state_tensor, @@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper): state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - - collected_states[state_name] = state_tensor.reshape(global_shape) + state_tensor = state_tensor.reshape(global_shape) + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) + collected_states[state_name] = state_tensor return collected_states # Check whether the param with given id is managed by current process. @@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper): if state_tensor.numel() == param.numel(): collected_states[state_name] = torch.reshape(state_tensor, param.shape) if is_dtensor: + global_shape = get_global_shape(param) state_tensor = state_tensor.to(param.device) state_tensor = init_as_dtensor( state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape @@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper): state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) return collected_states @@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper): Load saved optimizer states into parameter with given id. """ - def cast(param, state_range, value, key=None): + def cast(param, state_range, value, global_shape, origin_shape, key=None): """ Make a copy of the needed segment of value and cast it to device of param. 
""" @@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper): ) if is_dtensor: - value = torch.reshape(value, global_shape) + global_shape = get_global_shape(real_param) + + if is_padded_tensor(real_param): + value = torch.reshape(value, origin_shape) + padding_dim = real_param._padding_dim + value = to_padded_tensor(value, global_shape[padding_dim], padding_dim) + + if is_dtensor: value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) elif is_customized_distributed: value = torch.reshape(value, global_shape) @@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper): is_customized_distributed = is_customized_distributed_tensor(real_param) shard_spec = get_sharding_spec(real_param) if is_dtensor else None device_mesh = get_device_mesh(real_param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] + origin_shape = global_shape for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, k) + updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index d8a625b98..4753ab637 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf optimizer.backward(loss) optimizer.step() - for group in optimizer.param_groups: - group["lr"] = 0.1 + optimizer.zero_grad() with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index b23a44f2d..91cc1a987 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool): dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) - assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.num_embeddings == 128 assert dist_embedding_1d.embedding_dim == 32 assert embedding_copy.weight is dist_embedding_1d.weight diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d5fc2c30f..a77ba39a1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -14,12 +14,14 @@ from torch.testing import assert_close from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import 
is_customized_distributed_tensor, is_distributed_tensor +from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor def build_model( @@ -247,11 +249,10 @@ def check_weight( continue if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): - sharded_weight_list = [ - torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) - ] - dist.all_gather(sharded_weight_list, sharded_weight, tp_group) - sharded_weight = torch.cat(sharded_weight_list, dim=dim) + sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False) + + if is_padded_tensor(sharded_weight): + sharded_weight = to_unpadded_tensor(sharded_weight) if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 9b22d54d7..a6fe2dd39 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if test_config["precision"] == "fp32": - atol, rtol = 5e-4, 1e-3 + # TODO he precision in weight checking is too significant. + atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): diff --git a/tests/test_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor.py new file mode 100644 index 000000000..31a267c15 --- /dev/null +++ b/tests/test_tensor/test_padded_tensor.py @@ -0,0 +1,46 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_padded_tensor(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + original_tensor = torch.rand(32, 64).to("cuda") + + device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) + target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) + + padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0) + assert padded_tensor.dist_layout == d_tensor.dist_layout + + tensor_copy = padded_tensor.clone() + assert is_padded_tensor(tensor_copy) + assert is_distributed_tensor(tensor_copy) + + tensor_detached = padded_tensor.detach() + assert is_padded_tensor(tensor_detached) + assert is_distributed_tensor(tensor_detached) + + unpadded_tensor = to_unpadded_tensor(padded_tensor) + assert unpadded_tensor.shape == d_tensor.shape + assert is_distributed_tensor(unpadded_tensor) + + global_tensor = to_global(unpadded_tensor) + assert global_tensor.shape == original_tensor.shape + + +@rerun_if_address_is_in_use() +def test_padded_tensor(): + world_size = 4 + spawn(check_padded_tensor, world_size) + + +if __name__ == "__main__": + test_padded_tensor()
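
To make the new padded-tensor helpers easier to follow, here is a minimal standalone sketch of the API added in colossalai/tensor/padded_tensor/api.py above, run on a plain CPU tensor. The sizes (50257 padded to 50304) are only illustrative.

import torch

from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

weight = torch.randn(50257, 768)           # e.g. an unpadded GPT-2-sized embedding weight
padded = to_padded_tensor(weight, current_length=50304, padding_dim=0)
assert is_padded_tensor(padded)
assert padded.shape == (50304, 768)        # 47 zero rows appended along the vocab dimension
assert padded._origin_length == 50257      # padding metadata is recorded on the tensor itself

copy = padded.clone()                      # the hijacked clone keeps the padding metadata
assert is_padded_tensor(copy)

restored = to_unpadded_tensor(padded)      # drops the padding rows and the metadata
assert restored.shape == (50257, 768)
assert not is_padded_tensor(restored)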
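
The checkpoint and Gemini changes in this PR repeatedly use the same pattern: a gathered, full-size tensor is re-tagged with its padding metadata via init_as_padded_tensor and then truncated with to_unpadded_tensor before being saved, so the files on disk keep the original vocabulary size. A standalone sketch of that round trip (shapes are illustrative):

import torch

from colossalai.tensor.padded_tensor import init_as_padded_tensor, to_unpadded_tensor

gathered = torch.zeros(50304, 768)   # a padded parameter after gathering across TP ranks
gathered = init_as_padded_tensor(gathered, current_length=50304, origin_length=50257, padding_dim=0)
gathered = to_unpadded_tensor(gathered)
assert gathered.shape == (50257, 768)   # the checkpoint sees the original, unpadded shape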
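
The policies no longer call resize_token_embeddings in preprocess(); instead the Vocab*/Padding* layers pad the vocabulary internally, controlled by shard_config.make_vocab_size_divisible_by. The exact rule lives inside those layers; the helper below is only an assumption-level sketch of the usual Megatron-style convention (pad to a multiple of make_vocab_size_divisible_by times the tensor-parallel size), included to give intuition about the resulting sizes.

def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    """Hypothetical illustration of the padding rule; not the library's implementation."""
    multiple = make_vocab_size_divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple

assert padded_vocab_size(50257, 64, tp_size=1) == 50304
assert padded_vocab_size(50257, 64, tp_size=2) == 50304   # 50304 is already a multiple of 128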
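
Several preprocess() hooks now record self.tie_weight = self.tie_weight_check() instead of resizing embeddings. tie_weight_check itself is not shown in this diff; the helper below is a hypothetical illustration of what such a check typically means, namely that the input and output embeddings share one Parameter, so padding must be applied to both consistently.

import torch.nn as nn

def weights_are_tied(model: nn.Module) -> bool:
    # Hypothetical helper, not ColossalAI's tie_weight_check.
    get_in = getattr(model, "get_input_embeddings", None)
    get_out = getattr(model, "get_output_embeddings", None)
    if get_in is None or get_out is None:
        return False
    input_emb, output_emb = get_in(), get_out()
    if input_emb is None or output_emb is None:
        return False
    return input_emb.weight is output_emb.weight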