upgrade command

wangbluo 2025-05-08 16:06:05 +08:00
parent 46ed5d856b
commit a9bb7cb943
3 changed files with 58 additions and 102 deletions

View File

@@ -1,9 +1,8 @@
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -15,6 +14,12 @@ from transformers.models.cohere.modeling_cohere import (
     repeat_kv,
 )
 from transformers.utils import logging
+from transformers.processing_utils import Unpack
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.models.cohere.modeling_cohere import eager_attention_forward
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from functools import partial
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
@@ -27,6 +32,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"]
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
+logger = logging.get_logger(__name__)

 class CommandPipelineForwards:
     """
@@ -37,22 +43,23 @@ class CommandPipelineForwards:
     @staticmethod
     def command_model_forward(
         self: CohereModel,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
         force_sp_output_gather: bool = True,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ):
         logger = logging.get_logger(__name__)
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -67,8 +74,6 @@
             )
             use_cache = False
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
@@ -133,7 +138,7 @@
                 is_causal=True,
             )
         else:
-            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)

         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -164,6 +169,8 @@
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)

         start_idx, end_idx = stage_index[0], stage_index[1]
         num_ckpt_layers = 0
         if self.gradient_checkpointing and self.training:
@@ -193,6 +200,7 @@
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -203,6 +211,7 @@
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )
             hidden_states = layer_outputs[0]
@@ -224,17 +233,6 @@
             all_hidden_states += (hidden_states,)
         next_cache = next_decoder_cache if use_cache else None
         if stage_manager.is_last_stage():
-            if not return_dict:
-                return tuple(
-                    v
-                    for v in [
-                        hidden_states,
-                        next_cache,
-                        all_hidden_states,
-                        all_self_attns,
-                    ]
-                    if v is not None
-                )
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
                 past_key_values=next_cache,
@@ -350,29 +348,25 @@
             return {"hidden_states": hidden_states}


 def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if sp_mode is not None:
             assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
             assert (sp_size is not None) and (
                 sp_group is not None
             ), "Must specify sp_size and sp_group for sequence parallel"
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
         bsz, q_len, _ = hidden_states.size()

         # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
@@ -388,9 +382,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -403,7 +397,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
@@ -413,11 +408,14 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
+        attn_weights = None

         if shard_config.enable_flash_attention:
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
         else:
+            # to be fixed:
+            # precision issue
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -451,40 +449,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

         attn_output = self.o_proj(attn_output)

-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights

     return forward


 def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)

     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         force_sp_output_gather: bool = True,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+    ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -527,7 +521,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                 is_causal=True,
             )
         else:
-            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

         if sp_mode in ["ring", "split_gather"]:
             inputs_embeds = split_forward_gather_backward(
@@ -544,6 +538,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)

         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -557,6 +553,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings
                 )
             else:
@@ -568,16 +565,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings
                 )

             hidden_states = layer_outputs[0]

-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)

         # Cases that don't support parallelizing cross entropy computation along sequence
@@ -594,8 +586,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         next_cache = (
             next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
         )
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
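
Taken together, the hunks above change the forward interfaces: `return_dict` handling is dropped, rotary position embeddings are computed once per model forward via `self.rotary_emb(hidden_states, position_ids)` and threaded through every decoder layer as `position_embeddings`, and the attention forward now returns a two-tuple `(attn_output, attn_weights)` instead of `(attn_output, attn_weights, past_key_value)`. The following is a minimal, self-contained sketch of that calling convention using toy modules (hypothetical names and shapes, not the ColossalAI or transformers implementation):

# Toy sketch of the new calling convention (illustration only, not the real code).
import torch
from torch import nn
from typing import Optional, Tuple


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # Standard rotary helper: swap and negate the two halves of the last dim.
    t1, t2 = t.chunk(2, dim=-1)
    return torch.cat((-t2, t1), dim=-1)


class ToyRotaryEmbedding(nn.Module):
    """Computes (cos, sin) once per model forward; the caller passes them to every layer."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        freqs = position_ids[..., None].float() * self.inv_freq  # (bsz, seq_len, head_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)                  # (bsz, seq_len, head_dim)
        return emb.cos(), emb.sin()


class ToyAttention(nn.Module):
    """Consumes precomputed position_embeddings and returns (attn_output, attn_weights)."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # (cos, sin) from the caller
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        q = q.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        # No per-layer rotary_emb call: reuse the (cos, sin) computed by the model forward.
        cos, sin = position_embeddings
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        attn_output = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1).contiguous()
        return self.o_proj(attn_output), None  # two-tuple: no past_key_value in the return


# Usage: compute (cos, sin) once, then reuse it for every layer in the stack.
hidden = torch.randn(2, 8, 64)
position_ids = torch.arange(8).unsqueeze(0).expand(2, -1)
cos_sin = ToyRotaryEmbedding(head_dim=16)(hidden, position_ids)
attn_out, attn_weights = ToyAttention(hidden_size=64, num_heads=4)(hidden, cos_sin)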

View File

@@ -41,15 +41,13 @@ class CommandPolicy(Policy):
         from transformers.models.cohere.modeling_cohere import (
             CohereAttention,
             CohereDecoderLayer,
-            CohereFlashAttention2,
             CohereModel,
-            CohereSdpaAttention,
         )

         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
-            "flash_attention_2": CohereFlashAttention2,
-            "sdpa": CohereSdpaAttention,
+            "flash_attention_2": CohereAttention,
+            "sdpa": CohereAttention,
         }

         policy = {}
@@ -61,11 +59,6 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding

-        if self.shard_config.enable_fused_normalization:
-            norm_cls = FusedLayerNorm
-        else:
-            norm_cls = LayerNorm
-
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -280,29 +273,6 @@ class CommandPolicy(Policy):
             target_key=CohereModel,
         )

-        # optimization configuration
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="input_layernorm",
-                    target_module=norm_cls,
-                    kwargs={"sp_partial_derived": sp_partial_derived},
-                ),
-            ],
-            policy=policy,
-            target_key=CohereDecoderLayer,
-        )
-
-        self.append_or_create_submodule_replacement(
-            description=SubModuleReplacementDescription(
-                suffix="norm",
-                target_module=norm_cls,
-                kwargs={"sp_partial_derived": sp_partial_derived},
-            ),
-            policy=policy,
-            target_key=CohereModel,
-        )
-
         return policy

     def postprocess(self):
@@ -349,6 +319,7 @@ class CommandPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.layers))
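
Two things are visible in this policy diff. First, the implementation map now points every key at `CohereAttention`, since the separate `CohereFlashAttention2` and `CohereSdpaAttention` classes are no longer imported; this matches recent transformers releases, which dispatch eager/SDPA/flash attention inside a single attention class (see the `ALL_ATTENTION_FUNCTIONS` and `eager_attention_forward` imports added in the first file). Second, because position embeddings are now produced by a top-level `rotary_emb` module, each pipeline stage holds that module (`held_layers.append(module.rotary_emb)`). A minimal sketch of the single-target dispatch pattern, using a stand-in class rather than the real policy code:

# Stand-in sketch (assumed names, not the actual CommandPolicy code).
class CohereAttentionStub:
    """Placeholder for transformers' CohereAttention, which now covers all backends."""


ATTN_IMPLEMENTATION = {
    "eager": CohereAttentionStub,
    "flash_attention_2": CohereAttentionStub,
    "sdpa": CohereAttentionStub,
}


def target_attention_cls(attn_implementation: str) -> type:
    # The policy keeps a single module-replacement target regardless of the backend in use.
    return ATTN_IMPLEMENTATION[attn_implementation]


assert target_attention_cls("sdpa") is target_attention_cls("flash_attention_2")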

View File

@@ -213,17 +213,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
-            "use_lazy_init": True,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -231,6 +221,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -241,6 +232,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
@@ -252,6 +244,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "use_lazy_init": False,
+            "enable_flash_attention": True,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
@@ -260,6 +253,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
             "precision": "fp16",
@@ -270,6 +264,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 2,
             "enable_all_optimization": True,
+            "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",