From fbf33ecd019ce0e075b76b628e6e8a319cfc43e3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 9 Jul 2024 18:05:20 +0800 Subject: [PATCH] [Feature] Enable PP + SP for llama (#5868) * fix cross-PP-stage position id length diff bug * fix typo * fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use a one cross entropy func for all shardformer models --------- Co-authored-by: Edenzzzz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../booster/plugin/hybrid_parallel_plugin.py | 1 + colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/layer/loss.py | 45 +++++++++++- colossalai/shardformer/modeling/bloom.py | 56 ++++---------- colossalai/shardformer/modeling/command.py | 55 +++----------- colossalai/shardformer/modeling/gpt2.py | 47 ++---------- colossalai/shardformer/modeling/llama.py | 73 +++++++------------ colossalai/shardformer/modeling/mistral.py | 48 ++---------- colossalai/shardformer/modeling/opt.py | 57 +++------------ colossalai/shardformer/modeling/qwen2.py | 47 ++---------- colossalai/shardformer/policies/llama.py | 8 -- .../test_model/test_shard_llama.py | 31 ++++---- 12 files changed, 148 insertions(+), 323 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a3d6f1e74..485833398 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1205,6 +1205,7 @@ class HybridParallelPlugin(PipelinePluginBase): and self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all" ) + # sync gradients across DP * SP ranks if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) else: diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index f17fad1b6..331e49729 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d +from .loss import cross_entropy_1d, dist_cross_entropy from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row @@ -18,6 +18,7 @@ __all__ = [ "DropoutForParallelInput", "DropoutForReplicatedInput", "cross_entropy_1d", + "dist_cross_entropy", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index a6d19edf5..cea2da03f 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,8 +2,11 @@ import torch import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup +from torch.nn import CrossEntropyLoss -__all__ = ["DistCrossEntropy", "cross_entropy_1d"] +from colossalai.shardformer.shard import ShardConfig + +__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] class DistCrossEntropy(Function): @@ -132,3 +135,43 @@ def cross_entropy_1d( dtype: torch.dtype = None, ) -> torch.Tensor: return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) + + +def dist_cross_entropy( + labels: torch.Tensor, + logits: torch.Tensor, + shard_config: ShardConfig, + out_features: int, + vocab_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Helper to compute cross entropy loss for most shardformer models, + compatible with PP, TP and SP. + """ + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=out_features, + dtype=dtype, + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered. + # see VocabParallelLMHead1D + shift_logits = shift_logits.view(-1, vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + return loss diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 154143626..26ffef6c5 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -28,7 +28,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -359,30 +359,14 @@ class BloomPipelineForwards: hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + lm_logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.transformer.dtype, + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -1040,24 +1024,10 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - new_vocab_size = lm_logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 07a7f6cbf..72f705bc0 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( @@ -25,7 +24,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class CommandPipelineForwards: @@ -300,29 +299,9 @@ class CommandPipelineForwards: logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -658,24 +637,14 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits * self.logit_scale logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index aa75bab11..6ecda91c4 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,7 +25,7 @@ from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -372,27 +372,9 @@ class GPT2PipelineForwards: hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) - else: - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1282,24 +1264,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.transformer.dtype, - ) + loss = dist_cross_entropy( + labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf5ce45a8..54ff8e321 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -31,7 +31,7 @@ from colossalai.shardformer.layer._operation import ( ) from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class LlamaPipelineForwards: @@ -86,13 +86,20 @@ class LlamaPipelineForwards: device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device + # Support SP + PP + sp_mode = shard_config.sequence_parallelism_mode + sp_group = shard_config.sequence_parallel_process_group + sp_size = shard_config.sequence_parallel_size + if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): + # For correct positions ids. The states will be gather along the seq dim in the attention layer later. + seq_length *= sp_size + past_seen_tokens = 0 if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): @@ -101,7 +108,7 @@ class LlamaPipelineForwards: if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device) seq_length_with_past = seq_length + past_seen_tokens @@ -118,7 +125,6 @@ class LlamaPipelineForwards: if position_ids is None: position_ids = cache_position.unsqueeze(0) - # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -134,6 +140,13 @@ class LlamaPipelineForwards: else: attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + # Support SP + PP + if stage_manager.is_first_stage(): + if sp_mode in ["ring", "split_gather"]: + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) + if self.gradient_checkpointing and self.training and use_cache: if use_cache: logger.warning_once( @@ -196,6 +209,10 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) # add hidden states from the last decoder layer if output_hidden_states: @@ -304,29 +321,9 @@ class LlamaPipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -529,7 +526,6 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -804,24 +800,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 310c2d8e2..82e8ef5f9 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -19,7 +19,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy logger = logging.get_logger(__name__) @@ -275,29 +275,9 @@ class MistralForwards: logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -708,23 +688,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): logits = self.lm_head(hidden_states) logits = logits.float() - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index b250b4976..636b46cc4 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -22,7 +22,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d +from ..layer import dist_cross_entropy logger = logging.get_logger(__name__) @@ -330,30 +330,14 @@ class OPTPipelineForwards: ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - - if shard_config.enable_tensor_parallelism and shard_config.parallel_output: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - shift_labels = shift_labels.view(-1) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) - else: - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + loss = dist_cross_entropy( + labels, + logits, + shard_config, + self.lm_head.out_features, + self.config.vocab_size, + self.model.decoder.dtype, + ) if not return_dict: output = (logits,) + outputs[1:] @@ -971,26 +955,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) logits = self.lm_head(outputs[0]).contiguous() - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, - shift_labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=self.lm_head.out_features, - dtype=self.model.decoder.dtype, - ) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 11c26822f..0f253730d 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -32,7 +32,7 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, cross_entropy_1d +from ..layer import ColoAttention, dist_cross_entropy class Qwen2PipelineForwards: @@ -317,25 +317,9 @@ class Qwen2PipelineForwards: if stage_manager.is_last_stage(): hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] @@ -737,26 +721,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) + loss = dist_cross_entropy( + labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 85ec6717d..36491b4b5 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -66,13 +65,6 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm - if self.pipeline_stage_manager is not None: - self.shard_config.enable_sequence_parallelism = False - self.shard_config.enable_sequence_overlap = False - self.shard_config.sequence_parallelism_mode = None - warnings.warn( - f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" - ) sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8fe18f69b..88e54176b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if ( booster.plugin.zero_stage in [1, 2] and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): + master2working = sharded_optimizer.get_master_to_working_map() for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): - working_p = sharded_optimizer.master_to_working_param[id(p2)] + working_p = master2working[id(p2)] grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grad_index = ( 0 @@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, { # Test ring + Flash attention "tp_size": 2, "pp_size": 1, @@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { # Ulysess + Flash attention - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, { "tp_size": 1, "pp_size": 1, @@ -245,7 +247,6 @@ def run_llama_test(test_config): except Exception as e: print(f"Failed config: {test_config}") raise e - clear_layout_converter() Randomizer.reset_index() torch.cuda.empty_cache()