Merge pull request #6298 from wangbluo/upgrade_command

upgrade command
This commit is contained in:
Hanks 2025-05-22 14:21:58 +08:00 committed by GitHub
commit 6a29abdefd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 134 deletions

View File

@ -1,19 +1,19 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereForCausalLM,
CohereModel,
StaticCache,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -27,6 +27,8 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
logger = logging.get_logger(__name__)
class CommandPipelineForwards:
"""
@ -37,22 +39,23 @@ class CommandPipelineForwards:
@staticmethod
def command_model_forward(
self: CohereModel,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_output_gather: bool = True,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
):
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -67,8 +70,6 @@ class CommandPipelineForwards:
)
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@ -122,6 +123,7 @@ class CommandPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
@ -133,7 +135,8 @@ class CommandPipelineForwards:
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
@ -163,6 +166,8 @@ class CommandPipelineForwards:
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
@ -193,6 +198,7 @@ class CommandPipelineForwards:
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
@ -203,6 +209,7 @@ class CommandPipelineForwards:
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
@ -224,17 +231,6 @@ class CommandPipelineForwards:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -352,27 +348,22 @@ class CommandPipelineForwards:
def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
self: CohereAttention,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
@ -388,60 +379,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
cos, sin = position_embeddings
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = None
shard_config.enable_flash_attention = True
if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
# attn_weights and attn_output calculation is modified on the v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
dropout = 0.0 if not self.training else self.attention_dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.transpose(1, 2).contiguous()
# sp: all-to-all comminucation when introducing sequence parallel
@ -451,13 +418,11 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights
return forward
@ -467,24 +432,23 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
@ -516,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
shard_config.enable_flash_attention = True
# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
@ -527,7 +493,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# v4.51.3 transformers attention_mask calculation
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(
@ -544,6 +511,9 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
all_self_attns = () if output_attentions else None
next_decoder_cache = None
# v4.51.3 transformers position_embeddings calculation
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@ -557,8 +527,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
@ -568,16 +538,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence
@ -594,8 +559,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,

View File

@ -6,8 +6,6 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
FusedLayerNorm,
LayerNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
@ -38,18 +36,13 @@ class CommandPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereFlashAttention2,
CohereModel,
CohereSdpaAttention,
)
from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
# The eager, flash_attention_2, sdpa will all be passed to CohereAttention in v4.51.3 transformers.
ATTN_IMPLEMENTATION = {
"eager": CohereAttention,
"flash_attention_2": CohereFlashAttention2,
"sdpa": CohereSdpaAttention,
"flash_attention_2": CohereAttention,
"sdpa": CohereAttention,
}
policy = {}
@ -61,15 +54,11 @@ class CommandPolicy(Policy):
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
# CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it.
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "ring_attn" and not self.is_causal:
raise ValueError("Ring attention is only meant for causal language modeling.")
@ -86,14 +75,15 @@ class CommandPolicy(Policy):
policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
self.append_or_create_method_replacement(
description={
@ -280,29 +270,6 @@ class CommandPolicy(Policy):
target_key=CohereModel,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,
target_key=CohereDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
),
policy=policy,
target_key=CohereModel,
)
return policy
def postprocess(self):
@ -349,6 +316,7 @@ class CommandPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))

View File

@ -218,7 +218,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": False,
"use_lazy_init": True,
"precision": "fp16",